image-denoiser / models /autoencoder.py
Kajuto's picture
Initial commit - image denoiser + SR + MLOps stack
8b83582
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Original Denoising Autoencoder (96×96 → 96×96)
# ---------------------------------------------------------------------------
class Encoder(nn.Module):
    """Convolutional encoder: 3×96×96 → 256×12×12.

    A stride-1 stem convolution followed by three stride-2 convolutions,
    each paired with batch normalisation and ReLU. Every stride-2 stage
    halves the spatial size (rounding up), giving an overall 8× reduction.
    """

    def __init__(self):
        super().__init__()
        # Channel width entering/leaving each stage, and the stride it uses:
        #   3×96×96 → 32×96×96 → 64×48×48 → 128×24×24 → 256×12×12
        widths = (3, 32, 64, 128, 256)
        strides = (1, 2, 2, 2)

        layers = []
        for c_in, c_out, stride in zip(widths, widths[1:], strides):
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # One flat Sequential so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into a 256-channel feature map."""
        features = self.net(x)
        return features
class Decoder(nn.Module):
    """Convolutional decoder: 256×12×12 → 3×96×96.

    Three stride-2 transposed convolutions (each with batch normalisation
    and ReLU) double the spatial size at every step; a final 3×3
    convolution plus sigmoid maps to 3 channels with values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Channel widths across the upsampling stages:
        #   256×12×12 → 128×24×24 → 64×48×48 → 32×96×96
        widths = (256, 128, 64, 32)

        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Output head: 32×96×96 → 3×96×96, squashed into [0, 1].
        layers += [
            nn.Conv2d(widths[-1], 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]

        # One flat Sequential so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a 256-channel feature map back into an image batch."""
        image = self.net(x)
        return image
class DenoisingAutoencoder(nn.Module):
    """Full denoising autoencoder: noisy image → clean image of the same size.

    The encoder downsamples by 8× (rounding each side up) and the decoder
    upsamples by exactly 8×, so for inputs whose height or width is not a
    multiple of 8 the raw decoder output is slightly larger than the input.
    ``forward`` crops that overhang so the output always matches the input's
    spatial size.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denoise ``x`` and return an image batch with the same H×W.

        Args:
            x: Input batch whose last two dimensions are height and width.

        Returns:
            The decoder output, cropped to the input's height and width.
        """
        height, width = x.shape[-2:]
        z = self.encoder(x)
        restored = self.decoder(z)
        # The stride-2 convolutions are anchored at the top-left corner, so
        # any rounding overhang sits at the bottom/right edge. When H and W
        # are multiples of 8 this slice covers the whole tensor.
        return restored[..., :height, :width]

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ---------------------------------------------------------------------------
# Super-Resolution Autoencoder (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------
class SREncoder(nn.Module):
    """Convolutional encoder: 3×48×48 → 256×6×6.

    Four conv → batch-norm → ReLU stages; the first keeps the spatial size
    and the remaining three each halve it (rounding up).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) per stage, in forward order.
        stage_specs = (
            (3, 32, 1),      # 3×48×48   → 32×48×48
            (32, 64, 2),     # 32×48×48  → 64×24×24
            (64, 128, 2),    # 64×24×24  → 128×12×12
            (128, 256, 2),   # 128×12×12 → 256×6×6
        )

        modules = []
        for in_ch, out_ch, step in stage_specs:
            modules.extend(
                (
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=step, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                )
            )

        # One flat Sequential so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of low-resolution images into a feature map."""
        encoded = self.net(x)
        return encoded
class SRDecoder(nn.Module):
    """Convolutional decoder: 256×6×6 → 3×96×96 (2× upscale).

    Four stride-2 transposed convolutions, one more than the encoder has
    stride-2 convolutions, so the output ends up at twice the resolution
    of the encoder's input. A final 3×3 convolution plus sigmoid yields a
    3-channel image with values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Channel widths across the upsampling stages:
        #   256×6×6 → 128×12×12 → 64×24×24 → 32×48×48 → 16×96×96
        # The last step (32 → 16) is the extra layer that gives the 2× upscale.
        widths = (256, 128, 64, 32, 16)

        modules = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            modules.append(
                nn.ConvTranspose2d(
                    in_ch,
                    out_ch,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.ReLU(inplace=True))

        # Output head: 16×96×96 → 3×96×96, squashed into [0, 1].
        modules.append(nn.Conv2d(widths[-1], 3, kernel_size=3, padding=1))
        modules.append(nn.Sigmoid())

        # One flat Sequential so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a 256-channel feature map into an upscaled image batch."""
        image = self.net(x)
        return image
class SuperResAutoencoder(nn.Module):
    """Denoise + 2× upscale: noisy 48×48 → clean 96×96.

    Fully convolutional — accepts any input height/width and returns
    exactly 2× that size. The encoder rounds each side up to a multiple of
    8 while downsampling, so the raw decoder output is
    ``16 * ceil(side / 8)``; ``forward`` crops the overhang that appears
    when a side is not a multiple of 8.
    """

    def __init__(self):
        super().__init__()
        self.encoder = SREncoder()
        self.decoder = SRDecoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denoise and upscale ``x`` by a factor of two.

        Args:
            x: Input batch whose last two dimensions are height and width.

        Returns:
            The decoder output, cropped to ``2 * height`` by ``2 * width``.
        """
        height, width = x.shape[-2:]
        upscaled = self.decoder(self.encoder(x))
        # The stride-2 convolutions are anchored at the top-left corner, so
        # any rounding overhang sits at the bottom/right edge. When H and W
        # are multiples of 8 this slice covers the whole tensor.
        return upscaled[..., : 2 * height, : 2 * width]

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)