import torch
import torch.nn as nn

# ---------------------------------------------------------------------------
# Original Denoising Autoencoder (96×96 → 96×96)
# ---------------------------------------------------------------------------


class Encoder(nn.Module):
    """Convolutional encoder: 3×96×96 → 256×12×12.

    One stride-1 convolution followed by three stride-2 convolutions, each
    followed by BatchNorm and ReLU, for an overall 8× spatial downsampling.
    The encoder is fully convolutional, so other sizes are accepted too. An
    ``H×W`` input yields a ``ceil(H/8)×ceil(W/8)`` feature map, because each
    stride-2 step (kernel 3, padding 1) rounds odd sizes up.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 3×96×96 → 32×96×96
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32×96×96 → 64×48×48
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64×48×48 → 128×24×24
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128×24×24 → 256×12×12
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (N, 3, H, W) into (N, 256, ceil(H/8), ceil(W/8))."""
        return self.net(x)


class Decoder(nn.Module):
    """Convolutional decoder: 256×12×12 → 3×96×96.

    Three transposed convolutions (kernel 3, stride 2, padding 1,
    output_padding 1) each exactly double the spatial size, so an ``h×w``
    feature map becomes ``8h×8w``. A final 3×3 convolution maps to 3 channels,
    and a Sigmoid keeps the outputs in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 256×12×12 → 128×24×24
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128×24×24 → 64×48×48
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64×48×48 → 32×96×96
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32×96×96 → 3×96×96
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a (N, 256, h, w) feature map into a (N, 3, 8h, 8w) image."""
        return self.net(x)


class DenoisingAutoencoder(nn.Module):
    """Full denoising autoencoder: noisy image → clean image.

    Built for 96×96 images, but accepts any ``(N, 3, H, W)`` input and returns
    an output with the same spatial size, with values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct a clean image from noisy ``x`` of shape (N, 3, H, W)."""
        height, width = x.shape[-2:]
        z = self.encoder(x)
        out = self.decoder(z)
        # The encoder rounds odd sizes up at each stride-2 step, so the decoder
        # yields 8*ceil(H/8) × 8*ceil(W/8). Crop the surplus bottom/right
        # rows/columns so the output matches the input. This is a no-op when
        # H and W are multiples of 8 (e.g. 96×96).
        return out[..., :height, :width]

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ---------------------------------------------------------------------------
# Super-Resolution Autoencoder (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------


class SREncoder(nn.Module):
    """Convolutional encoder: 3×48×48 → 256×6×6.

    One stride-1 convolution followed by three stride-2 convolutions, each
    followed by BatchNorm and ReLU, for an overall 8× spatial downsampling.
    An ``H×W`` input yields a ``ceil(H/8)×ceil(W/8)`` feature map, because
    each stride-2 step (kernel 3, padding 1) rounds odd sizes up.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 3×48×48 → 32×48×48
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32×48×48 → 64×24×24
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64×24×24 → 128×12×12
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128×12×12 → 256×6×6
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (N, 3, H, W) into (N, 256, ceil(H/8), ceil(W/8))."""
        return self.net(x)


class SRDecoder(nn.Module):
    """Convolutional decoder: 256×6×6 → 3×96×96 (2× the 48×48 input size).

    Four transposed convolutions (kernel 3, stride 2, padding 1,
    output_padding 1) each exactly double the spatial size, so an ``h×w``
    feature map becomes ``16h×16w``. This is one more doubling than the encoder
    has halvings, which gives the overall 2× upscale. A final 3×3 convolution
    maps to 3 channels, and a Sigmoid keeps the outputs in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 256×6×6 → 128×12×12
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128×12×12 → 64×24×24
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64×24×24 → 32×48×48
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32×48×48 → 16×96×96  ← extra layer gives the 2× upscale
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            # 16×96×96 → 3×96×96
            nn.Conv2d(16, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a (N, 256, h, w) feature map into a (N, 3, 16h, 16w) image."""
        return self.net(x)


class SuperResAutoencoder(nn.Module):
    """Denoise + 2× upscale: noisy 48×48 → clean 96×96.

    Fully convolutional: accepts any ``(N, 3, H, W)`` input and returns
    ``(N, 3, 2H, 2W)`` with values in [0, 1]. When H or W is not a multiple
    of 8, the raw decoder output is larger than 2H×2W, so the surplus is
    cropped away.
    """

    def __init__(self):
        super().__init__()
        self.encoder = SREncoder()
        self.decoder = SRDecoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denoise and upscale ``x`` (N, 3, H, W) to (N, 3, 2H, 2W)."""
        height, width = x.shape[-2:]
        out = self.decoder(self.encoder(x))
        # The encoder rounds odd sizes up at each of its three stride-2 steps,
        # and the decoder upsamples 16×. The result is 16*ceil(H/8) rows, which
        # exceeds 2H whenever H is not a multiple of 8. Crop the surplus
        # bottom/right rows/columns. This is a no-op for multiples of 8
        # (e.g. 48 → 96).
        return out[..., : 2 * height, : 2 * width]

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)