# Uploaded with huggingface_hub by ash12321 (commit 2bf4ec2, verified)
"""
Residual Convolutional Autoencoder for Image Reconstruction
Architecture: 6-layer encoder/decoder with residual blocks
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AEResidualBlock(nn.Module):
    """Two-conv residual block: conv-BN-ReLU-dropout-conv-BN, skip add, final ReLU.

    Spatial size and channel count are preserved (3x3 convs, padding=1), so the
    identity skip connection needs no projection.

    Args:
        channels: Number of input/output channels.
        dropout: Dropout2d rate applied after the first activation;
            a non-positive value disables dropout entirely.
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        # NOTE: attribute names (conv1, bn1, ...) are part of the state_dict
        # contract with existing checkpoints — do not rename.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout2d(dropout)

    def forward(self, x):
        """Return relu(F(x) + x) where F is the two-conv branch."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.dropout(h)
        h = self.bn2(self.conv2(h))
        return self.relu(h + x)
class ResidualConvAutoencoder(nn.Module):
    """Deep convolutional autoencoder with residual blocks.

    Six stride-2 stages map 256x256 RGB images down to a 4x4x512 grid, which a
    linear layer projects to a flat latent code; the decoder mirrors the path
    with transposed convolutions and ends in Tanh.

    Args:
        latent_dim: Size of the flat latent code (e.g. 512 or 768).
        dropout: Dropout rate used inside every residual block
            (e.g. 0.15 or 0.20).

    Input: (B, 3, 256, 256) images, assumed to be scaled to [-1, 1].
    Output of forward: tuple of reconstruction (B, 3, 256, 256) and
    latent codes (B, latent_dim).
    """

    def __init__(self, latent_dim=512, dropout=0.15):
        super().__init__()
        self.latent_dim = latent_dim
        self.dropout = dropout

        def down(c_in, c_out):
            # Stride-2 conv halves the spatial resolution.
            return [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def up(c_in, c_out):
            # Stride-2 transposed conv doubles the spatial resolution.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        # Layer order matches the original hand-written Sequential exactly,
        # so state_dict keys (encoder.0, encoder.1, ...) are unchanged.
        enc_layers = []
        for c_in, c_out in [(3, 64), (64, 128), (128, 256), (256, 512), (512, 512)]:
            enc_layers.extend(down(c_in, c_out))
            enc_layers.append(AEResidualBlock(c_out, dropout))
        enc_layers.extend(down(512, 512))  # final 8 -> 4 stage has no residual block
        self.encoder = nn.Sequential(*enc_layers)

        # Flat latent projections between the 4x4x512 grid and latent_dim.
        self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)

        # Decoder: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.
        dec_layers = []
        for c_in, c_out in [(512, 512), (512, 512), (512, 256), (256, 128), (128, 64)]:
            dec_layers.extend(up(c_in, c_out))
            dec_layers.append(AEResidualBlock(c_out, dropout))
        dec_layers.append(nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1))
        dec_layers.append(nn.Tanh())  # outputs in [-1, 1], matching the input range
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode then decode.

        Args:
            x: Input tensor (B, 3, 256, 256) in range [-1, 1].

        Returns:
            Tuple of (reconstruction (B, 3, 256, 256), latent (B, latent_dim)).
        """
        latent = self.encode(x)
        return self.decode(latent), latent

    def encode(self, x):
        """Return the latent code (B, latent_dim) for a batch of images."""
        feats = self.encoder(x)
        return self.fc_encoder(feats.flatten(1))

    def decode(self, latent):
        """Reconstruct images (B, 3, 256, 256) from latent codes."""
        grid = self.fc_decoder(latent).view(-1, 512, 4, 4)
        return self.decoder(grid)
def load_model(checkpoint_path, latent_dim=512, dropout=0.15, device='cuda'):
    """Restore a trained ResidualConvAutoencoder from a checkpoint file.

    The architecture hyperparameters must match the ones the checkpoint was
    trained with, otherwise ``load_state_dict`` will fail on shape mismatches.

    Args:
        checkpoint_path: Path to a .pth checkpoint containing a
            'model_state_dict' entry.
        latent_dim: Latent dimension (512 for Model A, 768 for Model B).
        dropout: Dropout rate (0.15 for Model A, 0.20 for Model B).
        device: Device to map the checkpoint onto and move the model to.

    Returns:
        Tuple of (model in eval mode on ``device``, full checkpoint dict
        with any additional metadata it carries).
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, checkpoint