nithinasadhu commited on
Commit
a319d6f
·
verified ·
1 Parent(s): ae76b1b

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +69 -0
utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ==============================================
# utils.py | Helper functions for Super-Resolution
# ==============================================
# FIX(review): the original file began with the Jupyter cell magic
# `%%writefile utils.py`, which is a SyntaxError when this module is
# imported or run as plain Python — it has been removed.

import tensorflow as tf
import tensorflow_datasets as tfds

# --- Define Image Dimensions ---
# Target high-resolution size; the low-resolution size is derived from it
# so the 2x relationship stays consistent if HR_SIZE is ever changed.
HR_SIZE = (128, 128)
LR_SIZE = (HR_SIZE[0] // 2, HR_SIZE[1] // 2)  # 64x64
12
+
13
+
14
def preprocess_image(data):
    """Convert one DIV2K example into a (model_input, target) training pair.

    The model's input is a bicubic-upscaled version of the LR image, so the
    network learns to refine an already-upscaled image rather than change
    resolution itself.

    Args:
        data: Feature dict with 'hr' and 'lr' uint8 image tensors, as
            yielded by tfds 'div2k/bicubic_x2' with as_supervised=False.

    Returns:
        Tuple (model_input_image, hr_image), both float32 in [0, 1] and of
        spatial size HR_SIZE.
    """
    # Normalize pixel values to [0, 1].
    hr_image = tf.cast(data['hr'], tf.float32) / 255.0
    lr_image = tf.cast(data['lr'], tf.float32) / 255.0

    # Resize to the fixed training dimensions.
    hr_image = tf.image.resize(hr_image, HR_SIZE, method='bicubic')
    lr_image = tf.image.resize(lr_image, LR_SIZE, method='bicubic')

    # Create the model input by upscaling the low-res image back to HR size.
    model_input_image = tf.image.resize(lr_image, HR_SIZE, method='bicubic')

    # FIX: bicubic interpolation can overshoot outside the source range, so
    # the resized tensors may contain values < 0 or > 1. Clamp both outputs
    # so the documented [0, 1] normalization contract actually holds.
    hr_image = tf.clip_by_value(hr_image, 0.0, 1.0)
    model_input_image = tf.clip_by_value(model_input_image, 0.0, 1.0)

    return model_input_image, hr_image
31
+
32
+
33
def load_div2k_data(batch_size=16):
    """Build train/validation tf.data pipelines for DIV2K ('bicubic_x2', 2x SR).

    Args:
        batch_size: Number of (model_input, hr_target) pairs per batch.

    Returns:
        Tuple (train_dataset, validation_dataset, ds_info). The training
        dataset repeats indefinitely, so pass steps_per_epoch when fitting.
    """
    print("Loading and preparing DIV2K dataset...")

    # as_supervised=False so preprocess_image receives the raw feature dict.
    (train_ds, valid_ds), ds_info = tfds.load(
        'div2k/bicubic_x2',
        split=['train', 'validation'],
        as_supervised=False,
        with_info=True
    )

    autotune = tf.data.AUTOTUNE

    # Training pipeline: map -> cache -> shuffle -> batch -> repeat -> prefetch.
    train_dataset = train_ds.map(preprocess_image, num_parallel_calls=autotune)
    train_dataset = train_dataset.cache()  # cache decoded examples for speed
    train_dataset = train_dataset.shuffle(buffer_size=100)
    train_dataset = train_dataset.batch(batch_size)
    train_dataset = train_dataset.repeat()  # infinite stream for multi-epoch training
    train_dataset = train_dataset.prefetch(buffer_size=autotune)

    # Validation pipeline: deterministic — no shuffling, no repeating.
    validation_dataset = valid_ds.map(preprocess_image, num_parallel_calls=autotune)
    validation_dataset = validation_dataset.batch(batch_size)
    validation_dataset = validation_dataset.cache()
    validation_dataset = validation_dataset.prefetch(buffer_size=autotune)

    print("✅ Dataset pipelines created successfully.")
    return train_dataset, validation_dataset, ds_info