File size: 2,220 Bytes
ae76b1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#train.py

import os
import tensorflow as tf
import matplotlib.pyplot as plt

# Import the new data loader and the existing model builder
from utils import load_div2k_data
from model import build_enhanced_model, psnr

# --- 1. Training Configuration ---
# Kept small so that batches of the larger images still fit in GPU memory.
BATCH_SIZE = 16
# Each epoch takes longer at this image size, so fewer are run by default;
# raise this for higher quality.
EPOCHS = 30

# --- 2. Load the Dataset ---
# The loader hands back the training set, the validation set and a metadata
# object whose `splits` mapping exposes `num_examples` for each split.
train_ds, valid_ds, ds_info = load_div2k_data(batch_size=BATCH_SIZE)

# Number of whole batches in each split (floor division drops any partial
# final batch). 'train' is evaluated first, then 'validation'.
# NOTE(review): handing these counts to fit() assumes both datasets can supply
# at least this many batches every epoch -- confirm in load_div2k_data.
steps_per_epoch, validation_steps = (
    ds_info.splits[split_name].num_examples // BATCH_SIZE
    for split_name in ('train', 'validation')
)

# --- 3. Build the Model for 128x128 Input ---
# presumably (height, width, channels) -- TODO confirm against the model code
INPUT_SHAPE = (128, 128, 3)
model = build_enhanced_model(input_shape=INPUT_SHAPE)
model.summary()

# --- 4. Train the Model ---
print("\nStarting model training on 128x128 images...")
# Fit on the training set for EPOCHS epochs, evaluating on the validation set
# as part of the same call. The step counts bound how many batches are drawn
# per epoch and per validation pass. The returned object is kept in `history`.
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)
print("Training finished.")

# --- 5. Save the New Model ---
# Destination file for the trained model.
model_path = 'models/sr_128_model.h5'

# Create the output directory from the path itself so the directory name and
# the file path cannot drift apart. exist_ok=True replaces the separate
# "exists?" check, which could race with another process creating the
# directory between the check and the makedirs call.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

model.save(model_path)
print(f"✅ Model for 128x128 saved to {model_path}")

# --- 6. Visualize a Test Result ---
print("\nVisualizing a sample prediction...")
# Pull a single batch from the validation dataset and plot its first sample
# next to the model's prediction for it.
for lr_batch, hr_batch in valid_ds.take(1):
    # First low-res / high-res pair of the batch
    lr_image, hr_image = lr_batch[0], hr_batch[0]

    # Add a leading axis of size 1 before predicting, then take the first
    # (only) element of the result.
    single_input = tf.expand_dims(lr_image, axis=0)
    pred_image = model.predict(single_input)[0]

    # One row of three panels: input, prediction, ground truth.
    plt.figure(figsize=(15, 6))
    panels = (
        (lr_image, 'Low-Res Input (128x128 Upscaled)'),
        # Clip values to [0,1] for display
        (tf.clip_by_value(pred_image, 0, 1), 'AI Super-Resolved Output'),
        (hr_image, 'Original High-Resolution'),
    )
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
    plt.show()