Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import os | |
# Target (height, width) passed to tf.image.resize for every image and mask.
IMG_SIZE = (256, 256)
def decode_image(image_file):
    """Load an image file as a 3-channel float32 tensor scaled to [0, 1].

    Args:
        image_file: Path of the image file to read.

    Returns:
        The decoded image, resized to IMG_SIZE and divided by 255.
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image detects the encoding (BMP/GIF/JPEG/PNG) itself;
    # expand_animations=False keeps the result a single 3-D frame.
    pixels = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    resized = tf.image.resize(pixels, IMG_SIZE)
    # Decoded values are 8-bit (0..255), so dividing by 255 maps them to [0, 1].
    return tf.cast(resized, tf.float32) / 255.0
def decode_mask(mask_file):
    """Load a segmentation mask file as a single-channel tensor of 0s and 1s.

    Args:
        mask_file: Path of the mask file to read.

    Returns:
        The decoded mask, resized to IMG_SIZE, scaled by 1/255 and rounded
        to the nearest integer so that each value is 0.0 or 1.0.
    """
    encoded = tf.io.read_file(mask_file)
    grayscale = tf.image.decode_image(encoded, channels=1, expand_animations=False)
    # Resizing interpolates between neighbouring pixels, so values are
    # re-binarized below after scaling into [0, 1].
    scaled = tf.cast(tf.image.resize(grayscale, IMG_SIZE), tf.float32) / 255.0
    return tf.math.round(scaled)
def process_path(image_path, mask_path):
    """Decode one (image, mask) pair from their file paths.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the matching mask file.

    Returns:
        A tuple ``(image, mask)`` produced by decode_image and decode_mask.
    """
    return decode_image(image_path), decode_mask(mask_path)
def augment(image, mask):
    """Randomly flip an (image, mask) pair and jitter the image brightness.

    Flips are applied to both tensors together so they stay aligned; the
    brightness change touches the image only.

    Args:
        image: Image tensor with values in [0, 1].
        mask: Mask tensor matching the image.

    Returns:
        A tuple ``(image, mask)`` after augmentation.
    """
    # Horizontal mirror with probability 0.5.
    mirror_horizontally = tf.random.uniform(()) > 0.5
    if mirror_horizontally:
        image, mask = tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

    # Vertical mirror with probability 0.5, drawn independently.
    mirror_vertically = tf.random.uniform(()) > 0.5
    if mirror_vertically:
        image, mask = tf.image.flip_up_down(image), tf.image.flip_up_down(mask)

    # Brightness shift in [-0.2, 0.2) on the image only, clamped back to [0, 1].
    image = tf.clip_by_value(
        tf.image.random_brightness(image, max_delta=0.2), 0.0, 1.0
    )
    return image, mask
def get_dataset(image_paths, mask_paths, batch_size=16, is_train=True):
    """
    Creates a tf.data.Dataset pipeline of batched (image, mask) pairs.

    Args:
        image_paths (list): List of file paths to images.
        mask_paths (list): List of file paths to corresponding masks, in the
            same order as image_paths.
        batch_size (int): Size of batches.
        is_train (bool): If true, applies shuffling and augmentations.

    Returns:
        tf.data.Dataset: Batched and prefetched dataset of (image, mask) pairs.

    Raises:
        ValueError: If image_paths and mask_paths differ in length.
    """
    num_samples = len(image_paths)
    if num_samples != len(mask_paths):
        # Fail early with a readable message instead of a TensorFlow shape error.
        raise ValueError(
            "image_paths and mask_paths must have the same length, "
            f"got {num_samples} and {len(mask_paths)}"
        )

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if is_train:
        # Shuffle before decoding: at this point elements are only path
        # strings, so a buffer covering the whole dataset is cheap and gives a
        # uniform shuffle (a smaller fixed buffer only shuffles locally).
        dataset = dataset.shuffle(buffer_size=max(num_samples, 1))

    # Map the image reading function across the dataset
    dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

    if is_train:
        # Apply data augmentations
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch for performance
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset