# Uploaded to Hugging Face by ash12321
# Commit 57eeb52 (verified): "Create model.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
# ----------------------------------------------------
# A helper block for the Residual Connection
# ----------------------------------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 conv/batch-norm layers with an additive skip connection.

    Main path: ``conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN``. The result
    is added to the shortcut path and passed through a final ReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the 1x1 projection
            shortcut (when one is used). With ``stride=2`` and padding 1 the
            spatial size is roughly halved.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. The convolutions have no bias because each one feeds a
        # BatchNorm, whose own shift parameter makes a conv bias redundant.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path. A 1x1 projection is needed only when the main path changes
        # the channel count or spatial size. Otherwise an empty Sequential
        # passes the input through unchanged.
        shape_changes = in_channels != out_channels or stride != 1
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual block: ``relu(main(x) + shortcut(x))``."""
        hidden = self.bn1(self.conv1(x))
        hidden = F.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        hidden = hidden + self.shortcut(x)
        return F.relu(hidden)
# ----------------------------------------------------
# The Main Residual Autoencoder Model
# ----------------------------------------------------
class ResidualConvAutoencoder(pl.LightningModule):
    """Convolutional autoencoder built from ``ResidualBlock`` stages.

    The encoder maps a 3-channel image to a ``latent_dim`` vector. The decoder
    maps that vector back to a 3-channel image whose values are squashed into
    the 0..1 range by a final sigmoid.

    Both halves meet at a 512x4x4 bottleneck. The layer comments assume 32x32
    inputs, which the three stride-2 encoder stages reduce to 4x4. An input
    that does not reduce to 4x4 will hit a shape mismatch at the encoder's
    ``nn.Linear``. The decoder always upsamples 4x4 -> 32x32.

    Args:
        latent_dim: Size of the bottleneck vector.
        dropout_rate: Dropout probability applied to the latent code. Dropout
            only has an effect in training mode.
    """

    def __init__(self, latent_dim=512, dropout_rate=0.2):
        super().__init__()
        self.latent_dim = latent_dim

        # Bottleneck shared by encoder (flatten) and decoder (unflatten).
        bottleneck_shape = (512, 4, 4)
        bottleneck_numel = 512 * 4 * 4

        # --- Encoder: image -> latent vector ---
        # Layers are collected in a list so the Sequential indices (and thus
        # state_dict keys such as "encoder.1.conv1.weight") follow the
        # construction order.
        encoder_layers = [nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)]  # 32x32 -> 32x32
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            # Each stride-2 stage halves H and W: 32 -> 16 -> 8 -> 4.
            encoder_layers.append(ResidualBlock(c_in, c_out, stride=2))
        encoder_layers += [
            nn.Flatten(),
            nn.Linear(bottleneck_numel, self.latent_dim),
            nn.Dropout(dropout_rate),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # --- Decoder: latent vector -> image ---
        decoder_layers = [
            nn.Linear(self.latent_dim, bottleneck_numel),
            nn.Unflatten(1, bottleneck_shape),
        ]
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            # Each stage doubles H and W: 4 -> 8 -> 16 -> 32.
            decoder_layers += [
                ResidualBlock(c_in, c_out),
                nn.Upsample(scale_factor=2, mode='nearest'),
            ]
        decoder_layers += [
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # Output pixel values between 0 and 1
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to a latent vector and return the decoded reconstruction."""
        return self.decoder(self.encoder(x))

    def training_step(self, batch, batch_idx):
        """Placeholder kept only so the LightningModule is complete.

        Ignores ``batch`` and ``batch_idx`` and returns a constant 0.0 tensor.
        That tensor does not depend on any parameter and does not require
        grad, so this step cannot train the model. Replace it with a real
        reconstruction loss before fitting.
        """
        return torch.tensor(0.0)

    def configure_optimizers(self):
        """Placeholder optimizer: Adam over all parameters with PyTorch's default settings."""
        return torch.optim.Adam(self.parameters())