# OxO_Image-Repair / models.py
import torch
import torch.nn as nn
import math
import os
# --- ResidualBlock, Upsampler, and Generator classes remain the same ---
class ResidualBlock(nn.Module):
    def __init__(self, num_features, kernel_size=3, bn=False, act=nn.ReLU(True), res_scale=1.0):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2
        m = []
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn: m.append(nn.BatchNorm2d(num_features))
        m.append(act)
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn: m.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
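
# Illustrative note (not in the original file): the block preserves the input
# shape, so any number of blocks can be stacked; res_scale < 1 (e.g. 0.1, as
# used in EDSR) damps the residual branch to stabilise very deep stacks.
# >>> block = ResidualBlock(64, res_scale=0.1)
# >>> block(torch.randn(1, 64, 16, 16)).shape
# torch.Size([1, 64, 16, 16])
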
class Upsampler(nn.Module):
    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1))
        m.append(nn.PixelShuffle(scale_factor))
        if act: m.append(act)
        self.body = nn.Sequential(*m)

    def forward(self, x):
        return self.body(x)
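
# Illustrative shape check (not in the original file): this is sub-pixel
# convolution. The 3x3 conv expands channels by scale_factor**2, then
# nn.PixelShuffle folds them into space: (B, C*r*r, H, W) -> (B, C, H*r, W*r).
# >>> up = Upsampler(scale_factor=2, num_features=64)
# >>> up(torch.randn(1, 64, 16, 16)).shape
# torch.Size([1, 64, 32, 32])
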
class Generator(nn.Module):
    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        act = nn.ReLU(True)
        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)
        res_blocks = [ResidualBlock(num_features, kernel_size=3, act=act, res_scale=res_scale) for _ in range(num_res_blocks)]
        res_blocks.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*res_blocks)
        m_tail = []
        if (scale_factor & (scale_factor - 1)) == 0:  # power of two: stack log2(scale) x2 upsamplers
            for _ in range(int(math.log2(scale_factor))):
                m_tail.append(Upsampler(scale_factor=2, num_features=num_features, act=None))
        elif scale_factor == 3:
            m_tail.append(Upsampler(scale_factor=3, num_features=num_features, act=None))
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*m_tail)
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        x = self.head(lr_img)
        res = self.body(x)
        res += x
        x = self.tail(res)
        x = self.final_conv(x)
        return x
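
# Illustrative end-to-end shape check (not in the original file): with
# scale_factor=4 the tail stacks two x2 Upsamplers (log2(4) == 2), so an LR
# input of (B, 3, H, W) comes out as (B, 3, 4H, 4W).
# >>> g = Generator(scale_factor=4)
# >>> g(torch.randn(1, 3, 24, 24)).shape
# torch.Size([1, 3, 96, 96])
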
# +++ NEW Discriminator Class +++
class Discriminator(nn.Module):
    """
    Simple CNN discriminator network. A PatchGAN-style discriminator is common
    for SR GANs, but this global variant is simpler: it takes an image (real HR
    or generated SR) and outputs a single logit.
    """
    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        super(Discriminator, self).__init__()
        # Initial block
        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        current_features = num_features_start
        for i in range(num_blocks):
            stride = 1 if i % 2 == 0 else 2  # Downsample every other block
            next_features = current_features * 2 if stride == 2 else current_features
            layers.extend([
                nn.Conv2d(current_features, next_features, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(next_features),  # BatchNorm is common in discriminators
                nn.LeakyReLU(0.2, inplace=True)
            ])
            current_features = next_features
        self.features = nn.Sequential(*layers)
        # Classifier part: the Linear layer needs to know the flattened size of
        # the feature extractor's output. Using AdaptiveAvgPool2d makes it
        # independent of the input image size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(current_features, 100),  # Example intermediate size
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)  # Output a single logit (no sigmoid here)
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), either real HR or fake SR.
        Returns:
            torch.Tensor: Output logits (B, 1). Higher values -> more likely "real".
        """
        batch_size = img.size(0)
        features = self.features(img)
        pooled = self.avgpool(features)
        # Flatten the output of avgpool for the linear layer
        pooled = pooled.view(batch_size, -1)
        output = self.classifier(pooled)
        return output
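
# Illustrative note (not in the original file): thanks to AdaptiveAvgPool2d the
# classifier head accepts any input resolution, so the same discriminator works
# for different HR crop sizes.
# >>> d = Discriminator()
# >>> d(torch.randn(2, 3, 96, 96)).shape
# torch.Size([2, 1])
# >>> d(torch.randn(2, 3, 128, 128)).shape
# torch.Size([2, 1])
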
# --- Main block for testing and saving ---
if __name__ == '__main__':
    # --- Generator Test (as before) ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8
    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy LR input for Generator
    gen_batch_size = 1
    lr_height = 32
    lr_width = 32
    in_channels = 3
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")
    # ... (rest of generator verification and saving code remains here) ...
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")
    # ... (Saving code as before) ...

    print("\n--- Testing Discriminator ---")
    # --- Discriminator Test ---
    DISC_FEATURES = 64  # Starting features for discriminator
    DISC_BLOCKS = 3     # Number of conv blocks in discriminator

    # Dummy HR/SR input for Discriminator (sized to match the Generator's output)
    disc_batch_size = 4  # Can differ from the generator test batch size
    hr_height = output_sr.shape[2]  # Use the calculated HR height
    hr_width = output_sr.shape[3]   # Use the calculated HR width
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    # Instantiate the Discriminator
    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()  # Set to evaluation mode for testing
    # print(discriminator)  # Optional: print structure

    # Perform a forward pass
    with torch.no_grad():
        output_logits = discriminator(dummy_hr)
    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # Verify output shape
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"
    print("Discriminator definition test successful!")

    # Optional: Count parameters
    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")