File size: 2,790 Bytes
7a5bb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import tensorflow as tf
import os

# Target (height, width) passed to tf.image.resize for every image and mask.
IMG_SIZE = (256, 256)

def decode_image(image_file):
    """Load an RGB image from disk, resize it to IMG_SIZE, and scale it to [0, 1].

    Args:
        image_file: Scalar string tensor holding the path of the image file.

    Returns:
        A float32 tensor of shape (IMG_SIZE[0], IMG_SIZE[1], 3) with values
        in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image detects the container format itself (JPEG, PNG, BMP, GIF),
    # so no per-format branch is needed; expand_animations=False keeps the
    # result 3-D even for animated GIFs.
    pixels = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    resized = tf.image.resize(pixels, IMG_SIZE)
    # Map the 0..255 intensity range onto [0, 1].
    return tf.cast(resized, tf.float32) / 255.0

def decode_mask(mask_file):
    """Read a single-channel segmentation mask and binarize it to {0, 1}.

    Args:
        mask_file: Scalar string tensor holding the path of the mask file.

    Returns:
        A float32 tensor of shape (IMG_SIZE[0], IMG_SIZE[1], 1) whose values
        are exactly 0.0 or 1.0.
    """
    mask = tf.io.read_file(mask_file)
    mask = tf.image.decode_image(mask, channels=1, expand_animations=False)
    # Nearest-neighbour resizing copies existing pixel values instead of
    # interpolating between them, so label boundaries are not smeared into
    # intermediate values (bilinear, the default, is meant for images).
    mask = tf.image.resize(mask, IMG_SIZE, method="nearest")
    # "nearest" keeps the decoded dtype (uint8), so convert before scaling.
    mask = tf.cast(mask, tf.float32) / 255.0
    # Threshold at 0.5: raw values >= 128 become 1, values <= 127 become 0.
    # NOTE(review): this assumes foreground is stored as a high value such as
    # 255; a mask saved with values 0/1 would round to all zeros — confirm
    # against the dataset.
    mask = tf.math.round(mask)
    return mask

def process_path(image_path, mask_path):
    """Turn one (image path, mask path) pair into decoded tensors.

    Args:
        image_path: Scalar string tensor with the image file path.
        mask_path: Scalar string tensor with the matching mask file path.

    Returns:
        Tuple ``(image, mask)`` as produced by ``decode_image`` and
        ``decode_mask``.
    """
    return decode_image(image_path), decode_mask(mask_path)

def augment(image, mask):
    """Apply random training-time augmentation to an image/mask pair.

    Flips are applied identically to both tensors so the mask stays aligned
    with the image; brightness jitter is applied to the image only.

    Args:
        image: Float image tensor, presumably in [0, 1] as produced by
            ``decode_image``.
        mask: Mask tensor aligned pixel-for-pixel with ``image``.

    Returns:
        Tuple ``(image, mask)``; the image is clipped back into [0, 1].
    """
    def _maybe_flip(flip_fn, img, msk):
        # One coin toss decides for both tensors, keeping them aligned.
        coin = tf.random.uniform(())
        return tf.cond(
            coin > 0.5,
            lambda: (flip_fn(img), flip_fn(msk)),
            lambda: (img, msk),
        )

    image, mask = _maybe_flip(tf.image.flip_left_right, image, mask)
    image, mask = _maybe_flip(tf.image.flip_up_down, image, mask)

    # Photometric jitter only; the mask's label values must not change.
    jittered = tf.image.random_brightness(image, max_delta=0.2)
    # The brightness offset can push pixels outside [0, 1], so clamp them.
    return tf.clip_by_value(jittered, 0.0, 1.0), mask

def get_dataset(image_paths, mask_paths, batch_size=16, is_train=True,
                shuffle_buffer_size=1000):
    """
    Creates a tf.data.Dataset pipeline of (image, mask) batches.

    Args:
        image_paths (list): List of file paths to images.
        mask_paths (list): List of file paths to corresponding masks, in the
            same order as ``image_paths``.
        batch_size (int): Size of batches. The final batch may be smaller
            because the remainder is not dropped.
        is_train (bool): If true, applies shuffling and augmentations.
        shuffle_buffer_size (int): Number of path pairs held in the shuffle
            buffer when ``is_train`` is true. Only file paths are buffered
            (decoding happens after shuffling), so a buffer as large as the
            dataset is cheap and gives a full shuffle.

    Returns:
        tf.data.Dataset yielding batched ``(image, mask)`` tuples.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if is_train:
        # Shuffle paths before decoding: buffering strings is far cheaper
        # than buffering decoded images.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Decode each (image path, mask path) pair in parallel.
    dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

    if is_train:
        # Random flips / brightness jitter, applied only to training data.
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch, then prefetch so input preparation overlaps with model execution.
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset