File size: 2,331 Bytes
a319d6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# ==============================================
# utils.py  |  Helper functions for Super-Resolution
# ==============================================

%%writefile utils.py
import tensorflow as tf
import tensorflow_datasets as tfds

# --- Define Image Dimensions ---
# Target spatial sizes as (height, width); the LR size is derived from HR
# so the 2x scale factor stays consistent across the module.
HR_SIZE = (128, 128)  # high-resolution target size
LR_SIZE = (HR_SIZE[0] // 2, HR_SIZE[1] // 2)  # 64x64


def preprocess_image(data):
    """Normalize and resize one DIV2K sample into a (model_input, target) pair.

    The model's input is a bicubic-upscaled version of the LR image, so the
    network learns to refine an already-upsampled image at HR resolution.

    Args:
        data: Mapping with 'hr' and 'lr' image tensors (as produced by
            tfds 'div2k/bicubic_x2'; assumed uint8 HWC — TODO confirm).

    Returns:
        Tuple (model_input_image, hr_image): float32 tensors of spatial size
        HR_SIZE with values clamped to [0, 1].
    """
    # Normalize pixel values to [0, 1]
    hr_image = tf.cast(data['hr'], tf.float32) / 255.0
    lr_image = tf.cast(data['lr'], tf.float32) / 255.0

    # Resize to target dimensions
    hr_image = tf.image.resize(hr_image, HR_SIZE, method='bicubic')
    lr_image = tf.image.resize(lr_image, LR_SIZE, method='bicubic')

    # Create model input by upscaling the low-res image
    model_input_image = tf.image.resize(lr_image, HR_SIZE, method='bicubic')

    # Bicubic kernels have negative lobes, so resizing can overshoot the
    # valid range (values slightly below 0 or above 1); clamp both outputs
    # so downstream losses/metrics see valid pixel intensities.
    hr_image = tf.clip_by_value(hr_image, 0.0, 1.0)
    model_input_image = tf.clip_by_value(model_input_image, 0.0, 1.0)

    return model_input_image, hr_image


def load_div2k_data(batch_size=16):
    """Build training and validation tf.data pipelines from DIV2K.

    Uses the 'bicubic_x2' variant of DIV2K for 2x super-resolution.

    Args:
        batch_size: Number of (input, target) pairs per batch. Defaults to 16.

    Returns:
        Tuple (train_dataset, validation_dataset, ds_info). The training
        dataset repeats indefinitely; both datasets are batched, preprocessed
        via preprocess_image, and prefetched.
    """
    print("Loading and preparing DIV2K dataset...")

    # Fetch both splits as raw feature dicts (as_supervised=False) because
    # preprocess_image handles the input/target pairing itself.
    (train_ds, valid_ds), ds_info = tfds.load(
        'div2k/bicubic_x2',
        split=['train', 'validation'],
        as_supervised=False,
        with_info=True
    )

    # --- Training pipeline: map -> cache -> shuffle -> batch -> repeat -> prefetch ---
    train_dataset = train_ds.map(
        preprocess_image, num_parallel_calls=tf.data.AUTOTUNE
    )
    train_dataset = train_dataset.cache()            # cache decoded samples
    train_dataset = train_dataset.shuffle(buffer_size=100)
    train_dataset = train_dataset.batch(batch_size)
    train_dataset = train_dataset.repeat()           # stream across epochs
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    # --- Validation pipeline: map -> batch -> cache -> prefetch ---
    validation_dataset = valid_ds.map(
        preprocess_image, num_parallel_calls=tf.data.AUTOTUNE
    )
    validation_dataset = validation_dataset.batch(batch_size)
    validation_dataset = validation_dataset.cache()
    validation_dataset = validation_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    print("✅ Dataset pipelines created successfully.")
    return train_dataset, validation_dataset, ds_info