import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


# ----------------------------------------------------
# A helper block for the Residual Connection
# ----------------------------------------------------
class ResidualBlock(nn.Module):
    """ResNet-style two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    The shortcut is the identity unless the stride or channel count
    changes, in which case a 1x1 conv + BatchNorm projects the input
    to the matching shape.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first conv (and of the 1x1 projection);
            values > 1 downsample spatially.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip connection: identity by default; a 1x1 conv + BN when the
        # spatial size or channel count differs between input and output.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the residual block; preserves spatial size when stride=1."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)


# ----------------------------------------------------
# The Main Residual Autoencoder Model
# ----------------------------------------------------
class ResidualConvAutoencoder(pl.LightningModule):
    """Convolutional autoencoder with residual blocks for 3x32x32 images.

    The encoder downsamples 32x32 -> 4x4 through three strided residual
    blocks and projects to a ``latent_dim`` vector; the decoder mirrors
    this with nearest-neighbour upsampling and ends in a Sigmoid so the
    reconstruction lies in [0, 1].

    Args:
        latent_dim: size of the bottleneck vector.
        dropout_rate: dropout applied to the latent vector during training.
        lr: learning rate for Adam (default 1e-3, Adam's own default, so
            existing callers see identical behavior).
    """

    def __init__(self, latent_dim=512, dropout_rate=0.2, lr=1e-3):
        super().__init__()
        self.latent_dim = latent_dim
        self.lr = lr

        # --- Encoder ---
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
            ResidualBlock(64, 128, stride=2),                      # 32x32 -> 16x16
            ResidualBlock(128, 256, stride=2),                     # 16x16 -> 8x8
            ResidualBlock(256, 512, stride=2),                     # 8x8   -> 4x4
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, self.latent_dim),
            nn.Dropout(dropout_rate),
        )

        # --- Decoder ---
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 512 * 4 * 4),
            nn.Unflatten(1, (512, 4, 4)),
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 4x4   -> 8x8
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 8x8   -> 16x16
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16x16 -> 32x32
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # Output pixel values between 0 and 1
        )

    def forward(self, x):
        """Encode then decode; returns the reconstruction of ``x``."""
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon

    def training_step(self, batch, batch_idx):
        """Optimize MSE reconstruction loss.

        NOTE(review): the original returned ``torch.tensor(0.0)`` — a
        constant with no grad_fn, which makes Lightning's automatic
        ``loss.backward()`` raise as soon as training starts. This
        computes a real reconstruction loss instead. Accepts either a
        bare image tensor or an ``(images, labels)`` batch as produced
        by common classification DataLoaders.
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        recon = self(x)
        loss = F.mse_loss(recon, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)