import os

import tensorflow as tf

IMG_SIZE = (256, 256)


def decode_image(image_file, size=None):
    """Read an image file and return it resized, as float32 in [0, 1].

    Args:
        image_file: Scalar string tensor (or Python str) holding the file path.
        size: ``(height, width)`` to resize to. ``None`` (default) uses the
            module-level ``IMG_SIZE``, looked up at call time.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in [0, 1].
    """
    size = IMG_SIZE if size is None else size
    raw = tf.io.read_file(image_file)
    # decode_image detects the format (BMP/GIF/JPEG/PNG) itself, so no
    # per-format decoder is needed. expand_animations=False keeps GIFs rank-3
    # (first frame only) so tf.image.resize receives a single image.
    image = tf.image.decode_image(raw, channels=3, expand_animations=False)
    # Bilinear resize (the default) yields float32 still on the 0-255 scale.
    image = tf.image.resize(image, size)
    # Rescale to [0, 1].
    image = tf.cast(image, tf.float32) / 255.0
    return image


def decode_mask(mask_file, size=None):
    """Read a segmentation mask file and return it resized and binarized.

    Args:
        mask_file: Scalar string tensor (or Python str) holding the file path.
        size: ``(height, width)`` to resize to. ``None`` (default) uses the
            module-level ``IMG_SIZE``, looked up at call time.

    Returns:
        A float32 tensor of shape ``(height, width, 1)`` whose values are 0.0
        or 1.0.
    """
    size = IMG_SIZE if size is None else size
    raw = tf.io.read_file(mask_file)
    mask = tf.image.decode_image(raw, channels=1, expand_animations=False)
    # Nearest-neighbour resizing only copies existing pixel values. Bilinear
    # interpolation (the default) would blend 0 and 255 into grey values along
    # every object boundary and distort the label map.
    mask = tf.image.resize(
        mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    # Scale to [0, 1] and round, so values >= 128 become 1 and the rest 0.
    # NOTE(review): this assumes foreground is stored as a high intensity
    # (e.g. 255). Masks encoded as 0/1 label indices would round to all zeros;
    # confirm the dataset's mask encoding.
    mask = tf.cast(mask, tf.float32) / 255.0
    mask = tf.math.round(mask)
    return mask


def process_path(image_path, mask_path, size=None):
    """Load and preprocess an (image, mask) pair from their file paths.

    Args:
        image_path: Path to the image file.
        mask_path: Path to the corresponding mask file.
        size: Optional ``(height, width)``; ``None`` uses ``IMG_SIZE``.

    Returns:
        Tuple ``(image, mask)`` as produced by ``decode_image`` and
        ``decode_mask``.
    """
    image = decode_image(image_path, size=size)
    mask = decode_mask(mask_path, size=size)
    return image, mask


def _random_flip(image, mask, flip_fn):
    """Apply ``flip_fn`` to both image and mask with probability 0.5."""
    do_flip = tf.random.uniform(()) > 0.5
    # One tf.cond for both tensors guarantees they get the same geometric
    # transform. It also works without AutoGraph, unlike a Python `if` on a
    # tensor.
    return tf.cond(
        do_flip,
        lambda: (flip_fn(image), flip_fn(mask)),
        lambda: (image, mask),
    )


def augment(image, mask):
    """Apply random data augmentation to an (image, mask) pair.

    Geometric transforms (horizontal and vertical flips, each with
    probability 0.5) are applied identically to image and mask. The
    photometric transform (random brightness) is applied to the image only.

    Args:
        image: Image tensor; expected to hold values in [0, 1].
        mask: Mask tensor with the same spatial size as ``image``.

    Returns:
        Tuple ``(image, mask)`` after augmentation; the image is clipped
        back to [0, 1].
    """
    image, mask = _random_flip(image, mask, tf.image.flip_left_right)
    image, mask = _random_flip(image, mask, tf.image.flip_up_down)
    # Brightness only changes pixel intensities, so the mask is left untouched.
    image = tf.image.random_brightness(image, max_delta=0.2)
    # Brightness shifts can push values outside [0, 1]; clip them back.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, mask


def get_dataset(image_paths, mask_paths, batch_size=16, is_train=True,
                shuffle_buffer_size=1000, img_size=None):
    """Create a batched, prefetched tf.data.Dataset of (image, mask) pairs.

    Args:
        image_paths (list): File paths to the images.
        mask_paths (list): File paths to the corresponding masks, in the same
            order as ``image_paths``.
        batch_size (int): Number of pairs per batch.
        is_train (bool): If true, shuffles the pairs and applies augmentation.
        shuffle_buffer_size (int): Shuffle buffer size used when ``is_train``
            is true. Defaults to 1000.
        img_size (tuple | None): ``(height, width)`` to resize to. ``None``
            uses the module-level ``IMG_SIZE``.

    Returns:
        tf.data.Dataset yielding ``(images, masks)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if is_train:
        # Shuffle the path pairs before decoding, so the buffer holds short
        # strings rather than decoded images.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    def _load(image_path, mask_path):
        return process_path(image_path, mask_path, size=img_size)

    # Decode files in parallel.
    dataset = dataset.map(_load, num_parallel_calls=tf.data.AUTOTUNE)

    if is_train:
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch, then prefetch to overlap input preparation with training.
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset