Spaces:
Sleeping
Sleeping
File size: 2,732 Bytes
2ab0040 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import torch
import torch.nn as nn
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to latent statistics.

    Three stride-2 convolutions each halve the spatial resolution, the result
    is flattened, and two independent linear heads produce the mean and the
    log-variance vectors.

    Args:
        latent_dim: Size of each output vector (mean and log-variance).
        image_size: Side length of the square input images. Defaults to 64,
            which reproduces the original fixed ``128 * 8 * 8`` layout.
        in_channels: Number of input image channels. Defaults to 3.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, which would leave an
            empty feature map after the three downsampling convolutions.
    """

    def __init__(self, latent_dim, image_size=64, in_channels=3):
        super().__init__()
        # A conv with kernel 4, stride 2, padding 1 maps a side of n to n // 2,
        # so three of them leave a feature map of side image_size // 8.
        feature_side = image_size // 8
        if feature_side < 1:
            raise ValueError(
                f"image_size must be at least 8, got {image_size}"
            )

        # Layer order (and therefore state_dict keys) matches the original.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # side n -> n / 2
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # n / 2 -> n / 4
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # n / 4 -> n / 8
            nn.ReLU(),
            nn.Flatten(),
        )

        flat_features = 128 * feature_side * feature_side
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``.

        ``x`` is expected to have ``in_channels`` channels and the spatial size
        given at construction; both outputs have ``latent_dim`` features.
        """
        features = self.conv(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar
class Decoder(nn.Module):
    """Transposed-convolution decoder conditioned on an extra vector.

    The latent vector is concatenated with the condition (the age, in this
    file), projected to a 128-channel feature map, and upsampled by three
    stride-2 transposed convolutions. A final sigmoid keeps every output
    value between 0 and 1.

    Args:
        latent_dim: Size of the latent vector ``z``.
        condition_dim: Number of condition features appended to ``z``.
        image_size: Side length of the square output images. Must be a
            positive multiple of 8. Defaults to 64, which reproduces the
            original fixed ``128 * 8 * 8`` layout.
        out_channels: Number of output image channels. Defaults to 3.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim, condition_dim=1, image_size=64,
                 out_channels=3):
        super().__init__()
        # Each transposed conv (kernel 4, stride 2, padding 1) doubles the
        # side, so the starting feature map must be image_size / 8 wide.
        if image_size < 8 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        self.feature_side = image_size // 8

        # The decoder takes the latent vector PLUS the condition.
        self.fc = nn.Linear(
            latent_dim + condition_dim,
            128 * self.feature_side * self.feature_side,
        )
        # Layer order (and therefore state_dict keys) matches the original.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # side n / 8 -> n / 4
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # n / 4 -> n / 2
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, 2, 1),  # n / 2 -> n
            nn.Sigmoid(),  # Output pixels between 0 and 1
        )

    def forward(self, z, age_condition):
        """Decode latent batch ``z`` under ``age_condition`` into images.

        Args:
            z: Latent batch with ``latent_dim`` features per row.
            age_condition: Condition batch, either two-dimensional with
                ``condition_dim`` features per row, or one-dimensional with
                one value per row (treated as a single-feature condition).

        Returns:
            The decoded image batch with values in ``[0, 1]``.
        """
        if age_condition.dim() == 1:
            # Accept a plain vector of per-sample ages; torch.cat below needs
            # the same number of dimensions as z.
            age_condition = age_condition.unsqueeze(1)
        # Match z's dtype so the concatenated tensor fits the linear layer
        # (e.g. integer or float64 ages alongside a float32 latent).
        age_condition = age_condition.to(dtype=z.dtype)

        # Concatenate latent identity with age condition
        z_cond = torch.cat((z, age_condition), dim=1)
        hidden = self.fc(z_cond)
        hidden = hidden.view(-1, 128, self.feature_side, self.feature_side)
        return self.deconv(hidden)
class GAP_CVAE(nn.Module):
    """Conditional VAE that pairs ``Encoder`` with an age-conditioned ``Decoder``.

    The encoder produces the mean and log-variance of the latent code; the
    decoder reconstructs an image from a latent vector together with a
    single-feature age condition.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim, condition_dim=1)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal.

        ``std`` is ``exp(0.5 * logvar)``; writing the sample this way keeps it
        differentiable with respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, age):
        """Encode ``x``, sample a latent vector, and decode it with ``age``.

        Returns:
            A tuple ``(reconstructed, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decoder(z, age)
        return reconstructed, mu, logvar

    def simulate_age(self, x, target_age):
        """Re-render ``x`` at ``target_age``; intended for inference only.

        Both inputs are moved to the device of the model's parameters. The
        encoder mean is used as the latent code (no sampling), and the whole
        computation runs with gradient tracking disabled, so the returned
        tensor does not require grad.

        Args:
            x: Image batch to transform.
            target_age: Age condition tensor passed to the decoder.

        Returns:
            The decoded image batch for the requested age.
        """
        device = next(self.parameters()).device
        x = x.to(device)
        target_age = target_age.to(device)

        # Inference path: no autograd graph is needed, so skip recording one.
        with torch.no_grad():
            # 1. Extract Identity Latent (mu)
            mu, _ = self.encoder(x)
            # 2. Decode with new target age
            projected_image = self.decoder(mu, target_age)
        return projected_image
|