""" ResidualConvAutoencoder - Deepfake Detection Model Architecture: 5-stage encoder-decoder with residual blocks """ import torch import torch.nn as nn class ResidualBlock(nn.Module): """Residual block with two conv layers and skip connection""" def __init__(self, channels): super().__init__() self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) self.bn2 = nn.BatchNorm2d(channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): residual = x out = self.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += residual return self.relu(out) class ResidualConvAutoencoder(nn.Module): """ Residual Convolutional Autoencoder for image reconstruction and deepfake detection. Args: latent_dim (int): Dimension of latent space (default: 512) Input: x: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1] Output: reconstructed: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1] latent: Tensor of shape (batch_size, latent_dim) """ def __init__(self, latent_dim=512): super().__init__() self.latent_dim = latent_dim # Encoder: 128x128 -> 4x4 self.encoder = nn.Sequential( # Stage 1: 128 -> 64 nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ResidualBlock(64), # Stage 2: 64 -> 32 nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), ResidualBlock(128), # Stage 3: 32 -> 16 nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ResidualBlock(256), # Stage 4: 16 -> 8 nn.Conv2d(256, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), ResidualBlock(512), # Stage 5: 8 -> 4 nn.Conv2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), ) # Bottleneck self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim) self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4) # Decoder: 4x4 -> 128x128 self.decoder = nn.Sequential( # Stage 1: 4 -> 8 nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), ResidualBlock(512), # Stage 2: 8 -> 16 nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ResidualBlock(256), # Stage 3: 16 -> 32 nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), ResidualBlock(128), # Stage 4: 32 -> 64 nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ResidualBlock(64), # Stage 5: 64 -> 128 nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh() # Output in [-1, 1] ) def forward(self, x): """ Forward pass through the autoencoder. Args: x: Input tensor of shape (batch_size, 3, 128, 128) Returns: reconstructed: Reconstructed image of shape (batch_size, 3, 128, 128) latent: Latent representation of shape (batch_size, latent_dim) """ # Encode x = self.encoder(x) x = x.view(x.size(0), -1) latent = self.fc_encoder(x) # Decode x = self.fc_decoder(latent) x = x.view(x.size(0), 512, 4, 4) reconstructed = self.decoder(x) return reconstructed, latent def encode(self, x): """Extract latent representation only""" x = self.encoder(x) x = x.view(x.size(0), -1) latent = self.fc_encoder(x) return latent def decode(self, latent): """Reconstruct from latent representation""" x = self.fc_decoder(latent) x = x.view(x.size(0), 512, 4, 4) reconstructed = self.decoder(x) return reconstructed def reconstruction_error(self, x, reduction='mean'): """ Calculate per-sample reconstruction error (MSE). Useful for anomaly/deepfake detection. Args: x: Input tensor reduction: 'mean' for average error, 'none' for per-sample errors Returns: Reconstruction error (MSE) """ reconstructed, _ = self.forward(x) error = (reconstructed - x) ** 2 if reduction == 'mean': return error.mean() elif reduction == 'none': return error.view(x.size(0), -1).mean(dim=1) else: raise ValueError(f"Unknown reduction: {reduction}") def load_model(checkpoint_path, device='cuda'): """ Load pretrained model from checkpoint. Args: checkpoint_path: Path to .ckpt file device: 'cuda' or 'cpu' Returns: model: Loaded ResidualConvAutoencoder in eval mode """ model = ResidualConvAutoencoder(latent_dim=512) checkpoint = torch.load(checkpoint_path, map_location=device) if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) elif 'state_dict' in checkpoint: model.load_state_dict(checkpoint['state_dict']) else: model.load_state_dict(checkpoint) model = model.to(device) model.eval() return model