# Web file-viewer residue: Spaces status "Sleeping"; file size 2,790 bytes; revision 7a5bb5d.
# The viewer's line-number gutter (1-81) that was fused onto the next line has been removed.
import tensorflow as tf
import os
# Target size handed to tf.image.resize for both images and masks, so every
# decoded sample comes out with the same spatial dimensions.
IMG_SIZE = (256, 256)
def decode_image(image_file):
    """Load an image file as a 3-channel float32 tensor scaled by 1/255.

    Args:
        image_file: Path of the image file, handed to ``tf.io.read_file``.

    Returns:
        The decoded image, resized to ``IMG_SIZE`` and divided by 255.
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image selects the decoder from the file contents, so no
    # format-specific call is needed; expand_animations=False keeps the
    # result 3-D even for animated GIFs.
    decoded = tf.image.decode_image(
        raw_bytes, channels=3, expand_animations=False
    )
    resized = tf.image.resize(decoded, IMG_SIZE)
    # Bring the 8-bit value range down to [0, 1].
    return tf.cast(resized, tf.float32) / 255.0
def decode_mask(mask_file):
    """Load a segmentation mask file as a single-channel 0/1 float32 tensor.

    Args:
        mask_file: Path of the mask file, handed to ``tf.io.read_file``.

    Returns:
        The decoded mask, resized to ``IMG_SIZE``, divided by 255 and
        rounded so that every pixel is either 0.0 or 1.0.
    """
    raw_bytes = tf.io.read_file(mask_file)
    decoded = tf.image.decode_image(
        raw_bytes, channels=1, expand_animations=False
    )
    resized = tf.image.resize(decoded, IMG_SIZE)
    scaled = tf.cast(resized, tf.float32) / 255.0
    # Resizing interpolates between neighbouring pixels, so the scaled mask
    # can hold fractional values; rounding snaps them back to 0 or 1.
    return tf.math.round(scaled)
def process_path(image_path, mask_path):
    """Turn an (image path, mask path) pair into decoded tensors.

    Args:
        image_path: Path of the input image file.
        mask_path: Path of the matching segmentation mask file.

    Returns:
        A tuple ``(image, mask)`` produced by ``decode_image`` and
        ``decode_mask`` respectively.
    """
    return decode_image(image_path), decode_mask(mask_path)
def augment(image, mask):
    """Randomly flip an image/mask pair and jitter the image brightness.

    Args:
        image: Image tensor; values are clipped to [0, 1] on return.
        mask: Segmentation mask tensor paired with ``image``.

    Returns:
        A tuple ``(image, mask)`` after augmentation.
    """
    # Each flip is drawn independently (horizontal first, then vertical) and
    # is applied to image and mask together so the pair stays aligned.
    for flip in (tf.image.flip_left_right, tf.image.flip_up_down):
        if tf.random.uniform(()) > 0.5:
            image, mask = flip(image), flip(mask)

    # Brightness is a photometric change, so only the image receives it.
    brightened = tf.image.random_brightness(image, max_delta=0.2)
    # The brightness shift can push values outside [0, 1]; clip them back.
    return tf.clip_by_value(brightened, 0.0, 1.0), mask
def get_dataset(image_paths, mask_paths, batch_size=16, is_train=True):
    """Create a batched, prefetched tf.data.Dataset of (image, mask) pairs.

    Args:
        image_paths (list): List of file paths to images.
        mask_paths (list): List of file paths to corresponding masks, in the
            same order as ``image_paths``.
        batch_size (int): Size of batches.
        is_train (bool): If true, applies shuffling and augmentations.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, mask)`` batches produced by
        ``process_path`` (and ``augment`` when ``is_train`` is true).

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """
    # Fail fast with a readable message instead of a low-level shape error.
    try:
        num_images, num_masks = len(image_paths), len(mask_paths)
    except TypeError:
        # Inputs without a usable length are left for TensorFlow to validate.
        pass
    else:
        if num_images != num_masks:
            raise ValueError(
                "image_paths and mask_paths must have the same length, "
                f"got {num_images} and {num_masks}."
            )

    # An explicit string dtype keeps empty inputs usable: an empty Python
    # list would otherwise be inferred as float32, which tf.io.read_file
    # rejects when the map function is traced.
    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    mask_paths = tf.convert_to_tensor(mask_paths, dtype=tf.string)
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if is_train:
        # Elements are only path strings at this point, so a buffer covering
        # the whole dataset is cheap and gives a uniform shuffle. A fixed
        # buffer smaller than the dataset would only shuffle locally.
        # The lower bound of 1 keeps an empty dataset valid.
        buffer_size = tf.maximum(dataset.cardinality(), 1)
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Decode the files in parallel.
    dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

    if is_train:
        # Apply data augmentations to training samples only.
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch so input preparation overlaps with consumption.
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)
|