File size: 15,164 Bytes
fd5c0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48
import argparse
import os
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm

# Import custom modules
from dataset import SRDataset # Make sure dataset.py is in the same directory
from models import Generator, Discriminator # Make sure models.py is in the same directory
from loss import PerceptualLoss # Make sure loss.py is in the same directory

def train(args):
    """Run the full SRGAN training loop.

    Builds the device, dataset/dataloader, generator and discriminator,
    the three loss terms (L1 content, BCE-with-logits adversarial,
    VGG-based perceptual) and one Adam optimizer per network, then
    alternates a discriminator step and a generator step on every batch
    for ``args.epochs`` epochs. Model ``state_dict`` checkpoints are
    written to ``args.save_dir`` every ``args.save_interval`` epochs and
    at the final epoch.

    Args:
        args: Parsed ``argparse.Namespace``; the expected fields are the
            ones declared in the ``__main__`` block of this script.

    Raises:
        SystemExit: If the dataset or the perceptual loss cannot be
            initialized (fatal setup errors).
    """
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # Create directories for saving models and potentially logs/outputs
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # Note: args.hr_dir and args.lr_dir are assumed to be valid paths by this point
    # due to checks in the __main__ block
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir, scale_factor=args.scale, patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        sys.exit(1)  # FIX: sys.exit instead of the interactive-only exit() helper
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        sys.exit(1)

    # Use a smaller subset for initial testing on CPU if needed
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,  # Set to 0 if you encounter issues on Windows/macOS
        # BUGFIX: the original compared the torch.device object to the string
        # 'cuda' (`device == 'cuda'`), which is False on older PyTorch, so
        # pinned memory was silently disabled on GPU. Compare device.type.
        pin_memory=(device.type == "cuda"),  # pin_memory only useful for GPU
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")

    # --- 3. Models ---
    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)

    discriminator = Discriminator(in_channels=3,  # Assuming RGB input for discriminator
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)

    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content Loss (Pixel-wise) - L1 is common for SR
    content_loss_criterion = nn.L1Loss().to(device)

    # Adversarial Loss - Measures how well G fools D and D identifies fakes
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)  # More stable than BCELoss + Sigmoid

    # Perceptual Loss (VGG-based)
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)  # Using L1 feature distance
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        sys.exit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    # --- Optional: Learning Rate Scheduler ---
    # Example: scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=args.lr_decay_step, gamma=0.5)
    # Example: scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.5)

    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        generator.train()     # Set generator to training mode
        discriminator.train() # Set discriminator to training mode
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)  # leave=True to keep bar after epoch

        for batch_idx, batch in enumerate(progress_bar):
            # Ensure batch is valid (dataset loader might return None on error in __getitem__)
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue

            try:
                lr_images = batch['lr'].to(device)  # Low-resolution images
                hr_images = batch['hr'].to(device)  # High-resolution (ground truth) images
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue  # Skip this batch

            # Create labels for adversarial loss
            # Real labels = 1, Fake labels = 0
            # Add some noise or use soft labels (e.g., 0.9 instead of 1.0) can sometimes help stabilize GAN training
            # Allocate directly on the target device instead of CPU-alloc + .to()
            real_labels = torch.ones((hr_images.size(0), 1), device=device)
            fake_labels = torch.zeros((hr_images.size(0), 1), device=device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()

            # Generate fake HR images
            # Use torch.no_grad() for generator forward pass when only training discriminator
            with torch.no_grad():
                fake_sr_images = generator(lr_images)  # No need to detach() if already in no_grad context

            # Loss for real images
            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)

            # Loss for fake images
            fake_logits = discriminator(fake_sr_images)  # Use the generated fakes
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)

            # Total discriminator loss (average of real/fake halves)
            loss_d = (loss_d_real + loss_d_fake) / 2

            # Backpropagate and update Discriminator
            loss_d.backward()
            # Optional: Gradient clipping for Discriminator (can help stability)
            # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
            optimizer_d.step()

            # -----------------
            #  Train Generator
            # (Typically done less frequently than discriminator, e.g., every k steps,
            # but for simplicity here we do it every step)
            # -----------------
            optimizer_g.zero_grad()

            # Generate fake HR images (this time track gradients for G)
            generated_sr_images = generator(lr_images)

            # --- Calculate Generator Losses ---
            # 1. Content Loss (e.g., L1 distance between generated and real HR)
            loss_content = content_loss_criterion(generated_sr_images, hr_images)

            # 2. Perceptual Loss (VGG feature distance)
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)

            # 3. Adversarial Loss (how well G fools D)
            # We want the discriminator to output 'real' (1) for the generated images
            # Pass generated images through the discriminator (ensure D is not in no_grad context here)
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)  # Use real_labels!

            # --- Combine Generator Losses ---
            # Weights control the balance between pixel accuracy, perceptual quality, and realism
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)

            # Backpropagate and update Generator
            loss_g.backward()
            # Optional: Gradient clipping for Generator
            # torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            # --- Update running losses and progress bar ---
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
                # Optional: Show individual components of G loss
                # 'L_Cont': f"{loss_content.item():.4f}",
                # 'L_Perc': f"{loss_perceptual.item():.4f}",
                # 'L_Adv': f"{loss_adversarial.item():.4f}"
            })

        # --- End of Epoch ---
        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time

        # Optional: Update learning rate schedulers
        # scheduler_g.step()
        # scheduler_d.step()
        # current_lr_g = optimizer_g.param_groups[0]['lr']

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save Checkpoint ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    # --- End of Training ---
    total_time = time.time() - start_time
    print(f"\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")


if __name__ == '__main__':
    # CLI entry point: parse arguments, derive/validate data directories,
    # print the resolved configuration, then hand off to train().
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # --- Data Args ---
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,  # Default to None, will be auto-set
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0, help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for DataLoader (set to 0 for Mac/Windows usually)')
    parser.add_argument('--patch_size', type=int, default=48, help='Size (height/width) of LR patches for training')

    # --- Model Args ---
    parser.add_argument('--gen_features', type=int, default=64, help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16, help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64, help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3, help='Number of conv blocks in Discriminator')

    # --- Training Args ---
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    # Loss weights follow the SRGAN paper: 1e-2 content, 1.0 perceptual, 1e-3 adversarial
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')

    # --- Other Args ---
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')
    # parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint file to resume training') # Example for adding resume functionality

    args = parser.parse_args()

    # --- Set and Validate Directories ---
    # Auto-set LR directory based on scale IF it wasn't provided via command line
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    # Validate HR directory
    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        sys.exit(1)  # FIX: sys.exit instead of the interactive-only exit() helper
    # Validate LR directory
    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir.")
        sys.exit(1)  # Exit if the directory is invalid

    print("\n--- Training Configuration ---")
    # Print configuration cleanly
    config_dict = vars(args)
    # Calculate terminal width for better formatting (optional)
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80  # Default if terminal size unavailable (e.g. piped output)

    print("-" * term_width)
    for key, value in config_dict.items():
        print(f"{key:<25}: {value}")  # Format for alignment
    print("-" * term_width)

    # Start the training process
    train(args)