Spaces:
Sleeping
Sleeping
File size: 2,331 Bytes
a319d6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | # ==============================================
# utils.py | Helper functions for Super-Resolution
# ==============================================
# (A Jupyter `%%writefile utils.py` magic was removed from here: it is not
# valid Python, so importing this module failed with SyntaxError. When
# generating this file from a notebook, the magic must be the very first
# line of the cell, above this header.)
import tensorflow as tf
import tensorflow_datasets as tfds
# --- Image dimensions, as (height, width) in the order tf.image.resize expects ---
# Target high-resolution size. The low-resolution size is exactly half of it on
# each axis, which matches the 2x super-resolution setup used below.
HR_SIZE = (128, 128)
LR_SIZE = tuple(side // 2 for side in HR_SIZE)  # (64, 64)
def preprocess_image(data):
    """Build one (model_input, target) pair for 2x super-resolution.

    Args:
        data: A single dataset element as a dict (the dataset is loaded with
            ``as_supervised=False``). Its ``'hr'`` and ``'lr'`` entries are
            read as images with 0-255 pixel values.

    Returns:
        A ``(model_input_image, hr_image)`` tuple of float32 tensors:

        - ``hr_image`` is the HR image resized to ``HR_SIZE``.
        - ``model_input_image`` is the LR image resized to ``LR_SIZE`` and then
          upscaled back to ``HR_SIZE`` with bicubic interpolation.

        Both tensors are clipped to [0, 1]. ``tf.image.resize`` is called
        without ``preserve_aspect_ratio``, so images are stretched to the
        target size.
    """
    # Normalize pixel values to [0, 1].
    hr_image = tf.cast(data['hr'], tf.float32) / 255.0
    lr_image = tf.cast(data['lr'], tf.float32) / 255.0

    # Resize to the working dimensions. `antialias=True` only has an effect
    # when downsampling, where it low-pass filters before sampling. The source
    # images are presumably far larger than HR_SIZE, and a plain bicubic
    # shrink of that size aliases badly.
    hr_image = tf.image.resize(hr_image, HR_SIZE, method='bicubic', antialias=True)
    lr_image = tf.image.resize(lr_image, LR_SIZE, method='bicubic', antialias=True)

    # The model input is the low-res image upscaled back to the target size.
    model_input_image = tf.image.resize(lr_image, HR_SIZE, method='bicubic')

    # Bicubic interpolation overshoots near edges (ringing), so it can produce
    # values slightly outside [0, 1]. Clip both tensors to keep the
    # normalized range promised above.
    hr_image = tf.clip_by_value(hr_image, 0.0, 1.0)
    model_input_image = tf.clip_by_value(model_input_image, 0.0, 1.0)
    return model_input_image, hr_image
def load_div2k_data(batch_size=16, *, shuffle_buffer_size=None):
    """Load DIV2K ('bicubic_x2') and build batched tf.data pipelines.

    Every example is mapped through ``preprocess_image`` into an
    ``(upscaled_lr, hr)`` pair.

    Args:
        batch_size: Number of examples per batch, for both pipelines.
        shuffle_buffer_size: Size of the training shuffle buffer.
            - ``None`` (the default) uses the number of training examples
              reported by ``ds_info``, so each epoch is a full shuffle
              instead of a local one.
            - If that count is unavailable, 100 is used.
            - The buffer holds that many preprocessed examples in memory;
              pass a smaller value to limit memory use.

    Returns:
        A ``(train_dataset, validation_dataset, ds_info)`` tuple:

        - ``train_dataset`` is cached, shuffled, batched, and repeats
          indefinitely. Callers must bound each epoch themselves, e.g. with
          ``steps_per_epoch`` in Keras ``fit``.
        - ``validation_dataset`` is batched, cached and finite.
    """
    print("Loading and preparing DIV2K dataset...")

    (train_ds, valid_ds), ds_info = tfds.load(
        'div2k/bicubic_x2',
        split=['train', 'validation'],
        as_supervised=False,  # keep dict elements; preprocess_image reads 'hr'/'lr'
        with_info=True,
    )

    if shuffle_buffer_size is None:
        # A buffer smaller than the dataset only shuffles within a sliding
        # window, so early batches always come from the start of the split.
        shuffle_buffer_size = ds_info.splits['train'].num_examples or 100

    # --- Training pipeline ---
    train_dataset = (
        train_ds
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        # preprocess_image has no random ops, so its output can be cached once.
        .cache()
        # Shuffle after cache so the order changes on every epoch.
        .shuffle(buffer_size=shuffle_buffer_size)
        .batch(batch_size)
        .repeat()  # infinite; see the Returns note above
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    # --- Validation pipeline ---
    validation_dataset = (
        valid_ds
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .cache()
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    print("✅ Dataset pipelines created successfully.")
    return train_dataset, validation_dataset, ds_info
|