# Age-Synthesis / gap_engine.py
# Uploaded by: Bhaskr
# Commit 2ab0040: Initial clean deployment
import torch
import numpy as np
import pickle
import sys
import os
# Loading the StyleGAN2 .pkl weights requires code from the StyleGAN2 repository.
# Rather than cloning it here, we use a more direct method where possible,
# or advise the user to run a setup script.
class StyleGAN2AgingEngine:
    """Age (and optionally gender) editing on top of a pretrained StyleGAN2 generator.

    Typical use: construct the engine and call ``load_model()``. Obtain a W+
    latent (for a real photo, via ``project_image``), then render edited
    versions with ``generate_at_age``.

    Args:
        stylegan_path: Path to a StyleGAN2-ADA ``.pkl`` file containing a
            ``'G_ema'`` entry.
        age_vector_path: Path to a ``.npy`` file holding the age direction.
        gender_vector_path: Path to a ``.npy`` gender direction. This is optional.
            If it is ``None`` or the file does not exist, gender edits are skipped.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
    """

    def __init__(self, stylegan_path, age_vector_path, gender_vector_path, device="cuda"):
        self.device = torch.device(device)
        self.stylegan_path = stylegan_path
        self.age_vector_path = age_vector_path
        self.gender_vector_path = gender_vector_path
        # These attributes are populated by load_model().
        self.G = None
        self.age_direction = None
        self.gender_direction = None

    def _require_generator(self):
        """Raise a clear error if the generator has not been loaded yet."""
        if self.G is None:
            raise RuntimeError("StyleGAN2 generator is not loaded. Call load_model() first.")

    def _load_direction(self, path):
        """Load a ``.npy`` direction vector as a float32 tensor on ``self.device``."""
        return torch.from_numpy(np.load(path)).to(self.device).float()

    def _empty_cuda_cache(self):
        """Release cached CUDA memory. This is a no-op on non-CUDA devices."""
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    def load_model(self):
        """Load the generator and the latent editing directions.

        Raises:
            FileNotFoundError: If the StyleGAN2 weights or the age vector file
                do not exist.
        """
        print(f"Loading StyleGAN2 from {self.stylegan_path}...")
        # StyleGAN2-ADA weights are pickles of objects whose classes live in the
        # original repository, so 'dnnlib' and 'torch_utils' must be importable
        # (e.g. on PYTHONPATH) for unpickling to succeed.
        if not os.path.exists(self.stylegan_path):
            raise FileNotFoundError("StyleGAN2 weights not found. Please run download_weights.py first.")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load weight files obtained from a trusted source.
        with open(self.stylegan_path, 'rb') as f:
            self.G = pickle.load(f)['G_ema'].to(self.device)

        print("Loading Age Boundary vector...")
        self.age_direction = self._load_direction(self.age_vector_path)

        print("Loading Gender Boundary vector...")
        # The gender vector is optional. A None path or a missing file simply
        # leaves gender_direction as None, which disables gender edits.
        if self.gender_vector_path is not None and os.path.exists(self.gender_vector_path):
            self.gender_direction = self._load_direction(self.gender_vector_path)

    def generate_at_age(self, latent_w, age_coeff, gender_coeff=0.0):
        """Render ``latent_w`` shifted along the age (and gender) directions.

        Args:
            latent_w: W/W+ latent tensor, e.g. shape (1, num_ws, w_dim) such as
                (1, 18, 512). It is copied to ``self.device``; the caller's tensor
                is not modified.
            age_coeff: Step size along the age direction (higher = older).
            gender_coeff: Step size along the gender direction. Ignored when no
                gender vector was loaded.

        Returns:
            numpy.ndarray: uint8 image of shape (H, W, C) for the first batch item.

        Raises:
            RuntimeError: If ``load_model()`` has not been called.
        """
        self._require_generator()
        if self.age_direction is None:
            raise RuntimeError("Age direction is not loaded. Call load_model() first.")

        w_aged = latent_w.to(self.device).clone()
        # reshape (unlike view) accepts (w_dim,) or (1, w_dim) directions whether
        # or not they are contiguous. The (1, 1, w_dim) result broadcasts the
        # shift over every W+ layer.
        w_aged += age_coeff * self.age_direction.reshape(1, 1, -1)
        if self.gender_direction is not None:
            w_aged += gender_coeff * self.gender_direction.reshape(1, 1, -1)

        with torch.no_grad():
            img = self.G.synthesis(w_aged, noise_mode='const')
        # Convert NCHW to NHWC and map the synthesis output to [0, 255].
        # The output is assumed to be roughly in [-1, 1], matching the
        # normalization used in project_image.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img.cpu().numpy()[0]

    def project_image(self, target_img_pil, num_steps=100):
        """Invert a real image into the StyleGAN2 latent space (W+).

        The search starts from the generator's average latent, broadcast to every
        layer. It then optimizes that W+ code with Adam (lr=0.1) to minimize pixel
        MSE against the target. This takes ~15-30 seconds on a GTX 1650.

        Args:
            target_img_pil: A PIL image, or any object whose ``.convert('RGB')``
                yields an HxWx3 uint8 array-like. It is resized to the generator
                resolution.
            num_steps: Number of optimization steps.

        Returns:
            torch.Tensor: Detached W+ latent of shape (1, num_ws, w_dim). If the
            optimization raises, the error is printed and the average latent is
            returned as a best-effort fallback.

        Raises:
            RuntimeError: If the generator has not been loaded.
        """
        import torch.nn.functional as F

        self._require_generator()

        # Preprocess: HWC uint8 -> NCHW float32 in [-1, 1] at generator resolution.
        target_img = np.array(target_img_pil.convert('RGB'))
        target_img = torch.from_numpy(target_img).permute(2, 0, 1).unsqueeze(0)
        target_img = target_img.to(self.device).to(torch.float32) / 127.5 - 1.0
        res = self.G.img_resolution
        target_img = F.interpolate(target_img, size=(res, res), mode='area')

        # Clear memory before heavy lifting.
        self._empty_cuda_cache()

        # Start from the average W, repeated for every synthesis layer (W+).
        w_avg = self.G.mapping.w_avg
        w_pivot = w_avg.clone().detach().unsqueeze(0).repeat(1, self.G.mapping.num_ws, 1)
        w_opt = w_pivot.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([w_opt], lr=0.1)

        # Only w_opt is optimized. Freeze the generator so backward() neither
        # allocates nor accumulates .grad for its weights; those grads were never
        # zeroed and only cost VRAM. The previous flags are restored in `finally`.
        frozen = [p for p in self.G.parameters() if p.requires_grad]
        for p in frozen:
            p.requires_grad_(False)

        print("Starting Inversion (Identity Lock)...")
        try:
            for step in range(num_steps):
                optimizer.zero_grad()
                synth_img = self.G.synthesis(w_opt, noise_mode='const')
                loss = F.mse_loss(synth_img, target_img)
                loss.backward()
                optimizer.step()
                if step % 2 == 0:
                    print(f" Step {step}/{num_steps}, Loss: {loss.item():.4f}")
            print("Inversion Complete.")
        except Exception as e:
            # Deliberate best-effort: report the error and fall back to the average face.
            print(f"Critical Error during inversion: {e}")
            import traceback
            traceback.print_exc()
            return w_pivot.detach()
        finally:
            for p in frozen:
                p.requires_grad_(True)
            self._empty_cuda_cache()
        return w_opt.detach()
# Setup instructions:
# This module requires the StyleGAN2-ADA repository: its 'dnnlib' and
# 'torch_utils' packages must be importable to unpickle the generator weights.
# A setup_path1.py script is intended to handle this setup automatically.