import torch
import torch.nn as nn
import math
import os


# --- ResidualBlock, Upsampler, and Generator classes remain the same ---
class ResidualBlock(nn.Module):
    def __init__(self, num_features, kernel_size=3, bn=False, act=nn.ReLU(True), res_scale=1.0):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2
        m = []
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        m.append(act)
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        # Scale the residual branch before adding the skip connection
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Module):
    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        super(Upsampler, self).__init__()
        m = []
        # Expand channels by scale_factor ** 2, then rearrange them into space
        m.append(nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1))
        m.append(nn.PixelShuffle(scale_factor))
        if act:
            m.append(act)
        self.body = nn.Sequential(*m)

    def forward(self, x):
        return self.body(x)


class Generator(nn.Module):
    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64,
                 num_res_blocks=16, res_scale=1.0):
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        act = nn.ReLU(True)  # ReLU has no parameters, so one instance can be shared across blocks
        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)
        res_blocks = [ResidualBlock(num_features, kernel_size=3, act=act, res_scale=res_scale)
                      for _ in range(num_res_blocks)]
        res_blocks.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*res_blocks)
        m_tail = []
        if (scale_factor & (scale_factor - 1)) == 0:  # scale_factor is a power of 2
            for _ in range(int(math.log2(scale_factor))):
                m_tail.append(Upsampler(scale_factor=2, num_features=num_features, act=None))
        elif scale_factor == 3:
            m_tail.append(Upsampler(scale_factor=3, num_features=num_features, act=None))
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*m_tail)
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        x = self.head(lr_img)
        res = self.body(x)
        res += x  # long skip connection around the residual body
        x = self.tail(res)
        x = self.final_conv(x)
        return x
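
# --- Illustrative sketch (not part of the original tests): sub-pixel shapes ---
# Why the Upsampler first expands channels by scale_factor ** 2:
# nn.PixelShuffle(r) rearranges a (B, C * r^2, H, W) tensor into
# (B, C, r*H, r*W), so an x4 Generator chains two x2 Upsamplers, as in the
# power-of-two branch above. The helper below is a hypothetical shape check
# added for illustration; it is runnable but never called by this script.
def _pixel_shuffle_shape_demo():
    stages = nn.Sequential(
        Upsampler(scale_factor=2, num_features=64, act=None),  # x2
        Upsampler(scale_factor=2, num_features=64, act=None),  # x2 again -> x4 overall
    )
    y = stages(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 64, 64, 64)  # spatial size x4, channel count unchanged
    return y.shape
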
# +++ NEW Discriminator Class +++
class Discriminator(nn.Module):
    """
    Simple CNN discriminator network (a PatchGAN-style head is common, but this is simpler).
    Takes an image (real HR or generated SR) and outputs a single logit.
    """
    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        super(Discriminator, self).__init__()
        # Initial block
        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        current_features = num_features_start
        for i in range(num_blocks):
            stride = 1 if i % 2 == 0 else 2  # Downsample every other block
            next_features = current_features * 2 if stride == 2 else current_features
            layers.extend([
                nn.Conv2d(current_features, next_features, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(next_features),  # BatchNorm is common in discriminators
                nn.LeakyReLU(0.2, inplace=True)
            ])
            current_features = next_features
        self.features = nn.Sequential(*layers)

        # Classifier part. The Linear layer needs a fixed input size, which depends
        # on the feature extractor's output; pooling to 1x1 with AdaptiveAvgPool2d
        # makes the classifier independent of the input image size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(current_features, 100),  # Example intermediate size
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)  # Output a single logit (no sigmoid here)
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), either real HR or fake SR.

        Returns:
            torch.Tensor: Output logits (B, 1). Higher values -> more likely "real".
        """
        batch_size = img.size(0)
        features = self.features(img)
        pooled = self.avgpool(features)
        # Flatten the output of avgpool for the linear layer
        pooled = pooled.view(batch_size, -1)
        output = self.classifier(pooled)
        return output


# --- Main block for testing and saving ---
if __name__ == '__main__':
    # --- Generator Test (as before) ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8
    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy LR input for Generator
    gen_batch_size = 1
    lr_height = 32
    lr_width = 32
    in_channels = 3
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES,
                          num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")
    # ... (rest of generator verification and saving code remains here) ...
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")
    # ... (Saving code as before) ...

    print("\n--- Testing Discriminator ---")
    # --- Discriminator Test ---
    DISC_FEATURES = 64  # Starting features for discriminator
    DISC_BLOCKS = 3     # Number of conv blocks in discriminator

    # Dummy HR/SR input for Discriminator (must match Generator's output size)
    disc_batch_size = 4  # Can be different from generator test batch size
    hr_height = output_sr.shape[2]  # Use the calculated HR height
    hr_width = output_sr.shape[3]   # Use the calculated HR width
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    # Instantiate the Discriminator
    discriminator = Discriminator(in_channels=in_channels, num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()  # Set to evaluation mode for testing
    # print(discriminator)  # Optional: Print structure

    # Perform a forward pass
    with torch.no_grad():
        output_logits = discriminator(dummy_hr)
    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # Verify output shape
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"
    print("Discriminator definition test successful!")

    # Optional: Count parameters
    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")
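
# +++ Illustrative adversarial training step (sketch only) +++
# A minimal sketch of how these two networks might be trained together; it is
# an assumption-laden example, not training code this script ships with.
# Assumed (hypothetical): paired (lr_batch, hr_batch) tensors from some
# DataLoader, pre-built optimizers opt_g / opt_d, and an SRGAN-style 1e-3
# weight on the adversarial term. nn.BCEWithLogitsLoss is used because the
# Discriminator returns raw logits (no sigmoid). Both networks should be in
# train() mode so BatchNorm behaves correctly.
def gan_training_step(generator, discriminator, opt_g, opt_d, lr_batch, hr_batch, device):
    adv_criterion = nn.BCEWithLogitsLoss()
    pixel_criterion = nn.L1Loss()
    real_labels = torch.ones(hr_batch.size(0), 1, device=device)
    fake_labels = torch.zeros(hr_batch.size(0), 1, device=device)

    # Discriminator step: push real HR toward 1, generated SR toward 0.
    # sr_batch is detached here so this backward pass does not touch the generator.
    opt_d.zero_grad()
    sr_batch = generator(lr_batch)
    d_loss = (adv_criterion(discriminator(hr_batch), real_labels)
              + adv_criterion(discriminator(sr_batch.detach()), fake_labels))
    d_loss.backward()
    opt_d.step()

    # Generator step: pixel-wise L1 loss plus a small adversarial term that
    # rewards fooling the (just-updated) discriminator.
    opt_g.zero_grad()
    g_loss = (pixel_criterion(sr_batch, hr_batch)
              + 1e-3 * adv_criterion(discriminator(sr_batch), real_labels))
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()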