Spaces:
Sleeping
Sleeping
Upload utils.py
Browse files
utils.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ==============================================
|
| 2 |
+
# utils.py | Helper functions for Super-Resolution
|
| 3 |
+
# ==============================================
|
| 4 |
+
|
| 5 |
+
%%writefile utils.py
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
import tensorflow_datasets as tfds
|
| 8 |
+
|
| 9 |
+
# --- Define Image Dimensions ---
# Target spatial size of the high-resolution (ground-truth) images.
HR_SIZE = (128, 128)
# Low-resolution input is half the HR size per dimension (2x SR task).
LR_SIZE = (HR_SIZE[0] // 2, HR_SIZE[1] // 2)  # 64x64
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def preprocess_image(data):
    """
    Normalize and resize one DIV2K example for pre-upsampling SR training.

    The model's input is a bicubic-upscaled version of the LR image, so the
    network only has to learn the residual high-frequency detail.

    Args:
        data: Feature dict with 'hr' and 'lr' uint8 image tensors, as yielded
            by tfds 'div2k/bicubic_x2' with as_supervised=False.

    Returns:
        Tuple (model_input_image, hr_image): both float32 tensors of spatial
        size HR_SIZE with values clipped to [0, 1].
    """
    # Normalize pixel values to [0, 1]
    hr_image = tf.cast(data['hr'], tf.float32) / 255.0
    lr_image = tf.cast(data['lr'], tf.float32) / 255.0

    # Resize to target dimensions. NOTE(review): this rescales the whole
    # image (aspect ratio may change); random cropping is the usual
    # alternative for SR training — confirm intended.
    hr_image = tf.image.resize(hr_image, HR_SIZE, method='bicubic')
    lr_image = tf.image.resize(lr_image, LR_SIZE, method='bicubic')

    # Create model input by upscaling the low-res image back to HR size.
    model_input_image = tf.image.resize(lr_image, HR_SIZE, method='bicubic')

    # FIX: bicubic interpolation overshoots, producing values slightly
    # outside [0, 1]; clip so inputs and targets stay in the valid range.
    model_input_image = tf.clip_by_value(model_input_image, 0.0, 1.0)
    hr_image = tf.clip_by_value(hr_image, 0.0, 1.0)

    return model_input_image, hr_image
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_div2k_data(batch_size=16):
    """
    Build train/validation tf.data pipelines for 2x super-resolution.

    Loads the 'div2k/bicubic_x2' dataset via TensorFlow Datasets and runs
    every example through preprocess_image.

    Args:
        batch_size: Number of examples per batch (default 16).

    Returns:
        Tuple (train_dataset, validation_dataset, ds_info). The training
        pipeline repeats indefinitely (pair it with steps_per_epoch); both
        pipelines prefetch asynchronously.
    """
    print("Loading and preparing DIV2K dataset...")

    # Fetch both splits. as_supervised=False hands back the raw feature
    # dict so preprocess_image can do its own normalization/resizing.
    splits, ds_info = tfds.load(
        'div2k/bicubic_x2',
        split=['train', 'validation'],
        as_supervised=False,
        with_info=True,
    )
    train_ds, valid_ds = splits

    # --- Training pipeline: map -> cache -> shuffle -> batch -> repeat -> prefetch ---
    train_dataset = train_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = train_dataset.cache()           # cache decoded tensors for performance
    train_dataset = train_dataset.shuffle(buffer_size=100)
    train_dataset = train_dataset.batch(batch_size)
    train_dataset = train_dataset.repeat()          # stream forever across epochs
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    # --- Validation pipeline: deterministic, no shuffle/repeat ---
    validation_dataset = valid_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    validation_dataset = validation_dataset.batch(batch_size)
    validation_dataset = validation_dataset.cache()
    validation_dataset = validation_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    print("✅ Dataset pipelines created successfully.")
    return train_dataset, validation_dataset, ds_info
|