File size: 5,926 Bytes
2bf4ec2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
"""
Residual Convolutional Autoencoder for Image Reconstruction
Architecture: 6-layer encoder/decoder with residual blocks
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AEResidualBlock(nn.Module):
    """Shape-preserving residual block: conv-BN-ReLU-dropout-conv-BN plus skip.

    Both convolutions use 3x3 kernels with ``padding=1`` and keep the channel
    count, so the output has the same shape as the input and can be added to
    it directly.

    Args:
        channels: Number of input (and output) feature channels.
        dropout: Channel-wise dropout probability applied after the first
            activation. A value ``<= 0`` disables dropout (an ``nn.Identity``
            is used in its place).
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        # Sub-module names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Apply the residual transformation to ``x`` and return the result."""
        shortcut = x

        # First conv stage, followed by (optional) channel dropout.
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)

        # Second conv stage; activation is deferred until after the skip add.
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden += shortcut
        return self.relu(hidden)
class ResidualConvAutoencoder(nn.Module):
    """
    Deep convolutional autoencoder with residual connections.

    The encoder applies six stride-2 convolutions (3 -> 64 -> 128 -> 256 ->
    512 -> 512 -> 512 channels), each halving the spatial size; the first five
    are followed by an ``AEResidualBlock``. The resulting 512x4x4 feature map
    is projected to the latent vector by a linear layer. The decoder mirrors
    this with a linear layer and six stride-2 transposed convolutions, ending
    in ``Tanh`` so the reconstruction lies in [-1, 1].

    Args:
        latent_dim (int): Dimension of latent space (512 or 768)
        dropout (float): Dropout rate for regularization (0.15 or 0.20),
            forwarded to every residual block.

    Input: (B, 3, 256, 256) RGB images. The linear projection is sized for a
        512x4x4 bottleneck, so other spatial sizes fail at ``fc_encoder``.
    Output: (B, 3, 256, 256) Reconstructed images + (B, latent_dim) latent codes
    """

    # (channels, height, width) of the encoder's final feature map. Both
    # linear projections and the decoder reshape are derived from this.
    _BOTTLENECK_SHAPE = (512, 4, 4)

    # (in_channels, out_channels) of the stages that carry a residual block.
    # The order fixes the nn.Sequential indices and therefore the state_dict
    # keys, so it must not change if existing checkpoints are to keep loading.
    _ENCODER_PLAN = ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))
    _DECODER_PLAN = ((512, 512), (512, 512), (512, 256), (256, 128), (128, 64))

    def __init__(self, latent_dim=512, dropout=0.15):
        super().__init__()
        self.latent_dim = latent_dim
        self.dropout = dropout

        channels, height, width = self._BOTTLENECK_SHAPE
        bottleneck_features = channels * height * width

        # Encoder: 256x256 -> 4x4 (each stage halves the resolution).
        encoder_layers = []
        for in_ch, out_ch in self._ENCODER_PLAN:
            encoder_layers += self._down_stage(in_ch, out_ch, dropout)
        # Final 8 -> 4 stage has no residual block.
        encoder_layers += self._down_stage(
            channels, channels, dropout, residual=False
        )
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space projection
        self.fc_encoder = nn.Linear(bottleneck_features, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, bottleneck_features)

        # Decoder: 4x4 -> 256x256 (each stage doubles the resolution).
        decoder_layers = []
        for in_ch, out_ch in self._DECODER_PLAN:
            decoder_layers += self._up_stage(in_ch, out_ch, dropout)
        # Final 128 -> 256 stage maps to RGB; Tanh bounds output to [-1, 1].
        decoder_layers += [
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    @staticmethod
    def _down_stage(in_ch, out_ch, dropout, residual=True):
        """Return the layers of one stride-2 downsampling stage."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if residual:
            layers.append(AEResidualBlock(out_ch, dropout))
        return layers

    @staticmethod
    def _up_stage(in_ch, out_ch, dropout):
        """Return the layers of one stride-2 upsampling stage."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            AEResidualBlock(out_ch, dropout),
        ]

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor (B, 3, 256, 256) in range [-1, 1]

        Returns:
            reconstructed: Reconstructed tensor (B, 3, 256, 256)
            latent: Latent representation (B, latent_dim)
        """
        latent = self.encode(x)
        reconstructed = self.decode(latent)
        return reconstructed, latent

    def encode(self, x):
        """Get latent representation only"""
        features = self.encoder(x)
        # flatten() copies when needed, so this also works when the feature
        # map is not contiguous in NCHW order (e.g. channels_last), where
        # .view() would raise.
        return self.fc_encoder(features.flatten(1))

    def decode(self, latent):
        """Reconstruct from latent code"""
        flat = self.fc_decoder(latent)
        feature_map = flat.reshape(flat.size(0), *self._BOTTLENECK_SHAPE)
        return self.decoder(feature_map)
def load_model(checkpoint_path, latent_dim=512, dropout=0.15, device='cuda'):
    """
    Build a ResidualConvAutoencoder and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path to .pth checkpoint file
        latent_dim: Latent dimension (512 for Model A, 768 for Model B)
        dropout: Dropout rate (0.15 for Model A, 0.20 for Model B)
        device: Device the checkpoint tensors are mapped to and the model is
            moved to

    Returns:
        model: Loaded model in eval mode, on ``device``
        checkpoint: Full checkpoint dict as returned by ``torch.load``; its
            ``'model_state_dict'`` entry holds the weights that were restored
    """
    net = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])

    # eval() and to() both return the module itself, so they can be chained.
    return net.eval().to(device), checkpoint
|