""" Residual Convolutional Autoencoder for Image Reconstruction Architecture: 6-layer encoder/decoder with residual blocks """ import torch import torch.nn as nn import torch.nn.functional as F class AEResidualBlock(nn.Module): """Residual block with batch normalization and dropout""" def __init__(self, channels, dropout=0.1): super().__init__() self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) self.bn2 = nn.BatchNorm2d(channels) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() def forward(self, x): residual = x out = self.relu(self.bn1(self.conv1(x))) out = self.dropout(out) out = self.bn2(self.conv2(out)) out += residual return self.relu(out) class ResidualConvAutoencoder(nn.Module): """ Deep Convolutional Autoencoder with Residual Connections Args: latent_dim (int): Dimension of latent space (512 or 768) dropout (float): Dropout rate for regularization (0.15 or 0.20) Input: (B, 3, 256, 256) RGB images Output: (B, 3, 256, 256) Reconstructed images + (B, latent_dim) latent codes """ def __init__(self, latent_dim=512, dropout=0.15): super().__init__() self.latent_dim = latent_dim self.dropout = dropout # Encoder: 256x256 -> 4x4 self.encoder = nn.Sequential( # 256 -> 128 nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), AEResidualBlock(64, dropout), # 128 -> 64 nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), AEResidualBlock(128, dropout), # 64 -> 32 nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), AEResidualBlock(256, dropout), # 32 -> 16 nn.Conv2d(256, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), AEResidualBlock(512, dropout), # 16 -> 8 nn.Conv2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), AEResidualBlock(512, dropout), # 8 -> 4 nn.Conv2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), ) # Latent space projection self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim) self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4) # Decoder: 4x4 -> 256x256 self.decoder = nn.Sequential( # 4 -> 8 nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), AEResidualBlock(512, dropout), # 8 -> 16 nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), AEResidualBlock(512, dropout), # 16 -> 32 nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), AEResidualBlock(256, dropout), # 32 -> 64 nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), AEResidualBlock(128, dropout), # 64 -> 128 nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), AEResidualBlock(64, dropout), # 128 -> 256 nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh() # Output in [-1, 1] ) def forward(self, x): """ Forward pass Args: x: Input tensor (B, 3, 256, 256) in range [-1, 1] Returns: reconstructed: Reconstructed tensor (B, 3, 256, 256) latent: Latent representation (B, latent_dim) """ # Encode x = self.encoder(x) x = x.view(x.size(0), -1) latent = self.fc_encoder(x) # Decode x = self.fc_decoder(latent) x = x.view(x.size(0), 512, 4, 4) reconstructed = self.decoder(x) return reconstructed, latent def encode(self, x): """Get latent representation only""" x = self.encoder(x) x = x.view(x.size(0), -1) return self.fc_encoder(x) def decode(self, latent): """Reconstruct from latent code""" x = self.fc_decoder(latent) x = x.view(x.size(0), 512, 4, 4) return self.decoder(x) def load_model(checkpoint_path, latent_dim=512, dropout=0.15, device='cuda'): """ Load a trained model from checkpoint Args: checkpoint_path: Path to .pth checkpoint file latent_dim: Latent dimension (512 for Model A, 768 for Model B) dropout: Dropout rate (0.15 for Model A, 0.20 for Model B) device: Device to load model on Returns: model: Loaded model in eval mode checkpoint: Full checkpoint dict with metadata """ model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout) checkpoint = torch.load(checkpoint_path, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() model.to(device) return model, checkpoint