# NOTE(review): the three lines below were Hugging Face Spaces page chrome
# ("Spaces: Sleeping") captured by the scrape — kept here as a comment so the
# file remains valid Python.
| import torch | |
| import torch.nn as nn | |
| import math | |
| import os | |
| # --- ResidualBlock, Upsampler, and Generator classes remain the same --- | |
class ResidualBlock(nn.Module):
    """EDSR-style residual block: conv -> (BN) -> act -> conv -> (BN) with a
    scaled identity skip connection.

    Args:
        num_features: Channel count, kept constant through the block.
        kernel_size: Conv kernel size; `kernel_size // 2` padding preserves
            the spatial dimensions.
        bn: If True, insert BatchNorm2d after each conv.
        act: Activation module placed between the two convs. Defaults to a
            fresh ``nn.ReLU(True)`` per instance. (The previous signature used
            ``act=nn.ReLU(True)`` as the default, which is evaluated once at
            function-definition time, so every block built with the default
            shared a single ReLU module object.)
        res_scale: The residual branch is multiplied by this factor before the
            skip addition (stabilizes very deep trunks; EDSR uses 0.1).
    """
    def __init__(self, num_features, kernel_size=3, bn=False, act=None, res_scale=1.0):
        super(ResidualBlock, self).__init__()
        if act is None:
            act = nn.ReLU(True)  # fresh instance per block, not shared
        padding = kernel_size // 2
        m = []
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn: m.append(nn.BatchNorm2d(num_features))
        m.append(act)
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn: m.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale
    def forward(self, x):
        # Scale the residual branch, then add the identity skip.
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
class Upsampler(nn.Module):
    """Single sub-pixel upsampling stage: conv to scale^2 * C channels, then
    PixelShuffle, then an optional activation.

    Args:
        scale_factor: Spatial upscaling factor for this stage.
        num_features: Input (and output) channel count.
        act: Activation appended after the shuffle. Pass ``None`` to omit it.
            When left at the default, a fresh ``nn.ReLU(True)`` is created per
            instance. (The previous ``act=nn.ReLU(True)`` default was evaluated
            once at definition time, so all default-built upsamplers shared one
            ReLU module object; a sentinel keeps the explicit ``act=None``
            "no activation" behavior intact.)
    """
    _DEFAULT_ACT = object()  # sentinel: distinguishes "not given" from "None"
    def __init__(self, scale_factor, num_features, act=_DEFAULT_ACT):
        super(Upsampler, self).__init__()
        if act is self._DEFAULT_ACT:
            act = nn.ReLU(True)  # fresh instance per upsampler, not shared
        m = []
        m.append(nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1))
        m.append(nn.PixelShuffle(scale_factor))
        if act: m.append(act)
        self.body = nn.Sequential(*m)
    def forward(self, x):
        return self.body(x)
class Generator(nn.Module):
    """EDSR-style super-resolution generator.

    Pipeline: head conv (shallow features) -> residual trunk with a global
    skip connection -> pixel-shuffle upsampling tail -> final projection conv.

    Args:
        scale_factor: Overall upscaling factor. Powers of two are realized as
            a chain of x2 pixel-shuffle stages; 3 as a single x3 stage; any
            other value raises NotImplementedError.
        in_channels / out_channels: Image channel counts.
        num_features: Trunk width (channels carried through the body).
        num_res_blocks: Depth of the residual trunk.
        res_scale: Residual-branch scaling forwarded to each block.
    """
    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        activation = nn.ReLU(True)

        # Shallow feature extraction.
        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        # Deep trunk: residual blocks plus one fusion conv before the skip add.
        trunk = [ResidualBlock(num_features, kernel_size=3, act=activation, res_scale=res_scale) for _ in range(num_res_blocks)]
        trunk.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*trunk)

        # Upsampling tail (activations omitted, matching EDSR practice).
        is_power_of_two = (scale_factor & (scale_factor - 1)) == 0
        if is_power_of_two:
            stages = [Upsampler(scale_factor=2, num_features=num_features, act=None)
                      for _ in range(int(math.log2(scale_factor)))]
        elif scale_factor == 3:
            stages = [Upsampler(scale_factor=3, num_features=num_features, act=None)]
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*stages)

        # Project trunk features back to image channels.
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        """Map a low-res batch (B, C, H, W) to (B, C, H*scale, W*scale)."""
        shallow = self.head(lr_img)
        deep = self.body(shallow)
        deep += shallow  # global residual connection around the trunk
        upsampled = self.tail(deep)
        return self.final_conv(upsampled)
| # +++ NEW Discriminator Class +++ | |
class Discriminator(nn.Module):
    """Compact CNN discriminator producing one real/fake logit per image.

    A conv stack (downsampling by stride 2 on every odd-indexed block, doubling
    the channel width when it does) feeds a global average pool and a two-layer
    MLP head. No sigmoid is applied — the output is a raw logit, suitable for
    ``BCEWithLogitsLoss``.

    Args:
        in_channels: Channels of the input image (HR or SR).
        num_features_start: Width of the first conv layer.
        num_blocks: Number of conv/BN/LeakyReLU blocks after the stem.
    """
    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        super(Discriminator, self).__init__()
        # Stem: no BatchNorm on the first layer (common GAN practice).
        blocks = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        width = num_features_start
        for idx in range(num_blocks):
            # Odd-indexed blocks halve the spatial size and double the width.
            step = 2 if idx % 2 else 1
            out_width = width * 2 if step == 2 else width
            blocks += [
                nn.Conv2d(width, out_width, kernel_size=3, stride=step, padding=1),
                nn.BatchNorm2d(out_width),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            width = out_width
        self.features = nn.Sequential(*blocks)
        # Global pooling makes the classifier independent of input resolution,
        # so no hard-coded flattened size is needed.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(width, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), either real HR or fake SR.
        Returns:
            torch.Tensor: Output logits (B, 1). Higher values -> more likely "real".
        """
        feats = self.features(img)
        flat = self.avgpool(feats).view(img.size(0), -1)
        return self.classifier(flat)
# --- Main block for testing and saving ---
if __name__ == '__main__':
    # Smoke-test both networks with dummy tensors and report parameter counts.
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8

    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Generator smoke test: LR in, SR out ---
    gen_batch_size, in_channels = 1, 3
    lr_height = lr_width = 32
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")
    # ... (rest of generator verification and saving code remains here) ...
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")
    # ... (Saving code as before) ...

    print("\n--- Testing Discriminator ---")
    # --- Discriminator smoke test: HR/SR in, logit out ---
    DISC_FEATURES = 64  # starting feature width
    DISC_BLOCKS = 3     # number of conv blocks

    # The discriminator consumes HR-sized images, so reuse the generator's
    # output resolution; the batch size is independent of the generator test.
    disc_batch_size = 4
    hr_height, hr_width = output_sr.shape[2], output_sr.shape[3]
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()
    # print(discriminator)  # Optional: print structure
    with torch.no_grad():
        output_logits = discriminator(dummy_hr)
    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # Verify the (B, 1) logit shape.
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"
    print("Discriminator definition test successful!")

    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")