# Age-Synthesis / gap_model.py
# Page header captured with the file (author avatar alt text: "Bhaskr's picture").
# Commit: "Initial clean deployment" (2ab0040)
import torch
import torch.nn as nn
class Encoder(nn.Module):
    """Convolutional encoder producing Gaussian latent parameters for an image.

    Three stride-2 convolutions each halve the spatial resolution, so a square
    input of side ``image_size`` is reduced to ``image_size // 8`` per side
    with 128 channels before being flattened and projected to ``mu`` and
    ``logvar``.

    Args:
        latent_dim: Number of latent dimensions produced per image.
        image_size: Side length of the square, 3-channel input images.
            Defaults to 64, which yields the original ``128 * 8 * 8``
            flattened feature size.

    Raises:
        ValueError: If ``image_size`` is below 8, where the third convolution
            would produce an empty feature map.
    """

    def __init__(self, latent_dim: int, image_size: int = 64):
        super().__init__()
        if image_size < 8:
            raise ValueError(
                f"image_size must be at least 8, got {image_size}"
            )

        # Attribute names and layer order are kept unchanged so that
        # state_dict keys remain compatible with existing checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),    # S -> S/2   (64x64 -> 32x32)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),   # S/2 -> S/4 (32x32 -> 16x16)
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # S/4 -> S/8 (16x16 -> 8x8)
            nn.ReLU(),
            nn.Flatten(),
        )

        # Each kernel-4 / stride-2 / padding-1 convolution floors the side
        # length to half, so three of them give image_size // 8.
        reduced_side = image_size // 8
        feature_dim = 128 * reduced_side * reduced_side
        self.fc_mu = nn.Linear(feature_dim, latent_dim)
        self.fc_logvar = nn.Linear(feature_dim, latent_dim)

    def forward(self, x: torch.Tensor):
        """Encode a batch of images.

        Args:
            x: Image batch of shape ``(batch, 3, image_size, image_size)``.

        Returns:
            A ``(mu, logvar)`` tuple, each of shape ``(batch, latent_dim)``.
        """
        features = self.conv(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar
class Decoder(nn.Module):
    """Conditional decoder mapping a latent vector plus an age condition to an image.

    The latent vector and the condition are concatenated, projected to a
    ``128 x 8 x 8`` feature map, and upsampled by three stride-2 transposed
    convolutions to a ``3 x 64 x 64`` image with values in ``[0, 1]``.

    Args:
        latent_dim: Size of the latent vector ``z``.
        condition_dim: Number of condition features per sample (default 1).
    """

    def __init__(self, latent_dim: int, condition_dim: int = 1):
        super().__init__()
        # The decoder takes the latent vector PLUS the age condition.
        self.fc = nn.Linear(latent_dim + condition_dim, 128 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 32x32 -> 64x64
            nn.Sigmoid(),                          # Output pixels between 0 and 1
        )

    def forward(self, z: torch.Tensor, age_condition) -> torch.Tensor:
        """Decode latent vectors under the given age condition.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.
            age_condition: Condition values. Accepted forms are a
                ``(batch, condition_dim)`` tensor, a 1-D ``(batch,)`` tensor
                (one scalar age per sample), or a scalar (number or 0-D
                tensor) that is broadcast to every sample. It is converted
                to ``z``'s dtype and device before concatenation.

        Returns:
            Image batch of shape ``(batch, 3, 64, 64)`` with values in [0, 1].
        """
        batch_size = z.size(0)

        # Match z's dtype/device so a float64 or integer age does not
        # promote the concatenated tensor away from the Linear layer's dtype.
        cond = torch.as_tensor(age_condition, dtype=z.dtype, device=z.device)
        if cond.dim() == 0:
            cond = cond.expand(batch_size)  # broadcast a single age to the batch
        if cond.dim() == 1:
            cond = cond.unsqueeze(1)        # (batch,) -> (batch, 1)

        # Concatenate latent identity with age condition.
        z_cond = torch.cat((z, cond), dim=1)
        hidden = self.fc(z_cond).view(batch_size, 128, 8, 8)
        return self.deconv(hidden)
class GAP_CVAE(nn.Module):
    """Conditional VAE pairing an image encoder with an age-conditioned decoder.

    The encoder yields the latent Gaussian parameters for an image; the
    decoder reconstructs an image from a latent sample together with an age
    condition.

    Args:
        latent_dim: Size of the latent vector (default 128).
        condition_dim: Number of condition features handed to the decoder
            (default 1).
    """

    def __init__(self, latent_dim: int = 128, condition_dim: int = 1):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim, condition_dim=condition_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor, age: torch.Tensor):
        """Encode ``x``, sample a latent, and decode it under ``age``.

        Returns:
            A ``(reconstructed, mu, logvar)`` tuple.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decoder(z, age)
        return reconstructed, mu, logvar

    def simulate_age(self, x: torch.Tensor, target_age: torch.Tensor) -> torch.Tensor:
        """Re-render ``x`` at ``target_age``; used for inference.

        Inputs are moved to the model's device. The latent mean ``mu`` is
        used directly (no sampling), so the output is deterministic for a
        given input. The computation runs under ``torch.no_grad()`` so no
        autograd graph is built; the returned tensor does not require grad.

        Args:
            x: Image batch accepted by the encoder.
            target_age: Age condition tensor accepted by the decoder.

        Returns:
            The decoded image batch.
        """
        device = next(self.parameters()).device
        with torch.no_grad():
            # 1. Extract the identity latent (mu).
            mu, _ = self.encoder(x.to(device))
            # 2. Decode with the new target age.
            return self.decoder(mu, target_age.to(device))