import torch
import numpy as np
import pickle
import sys
import os

# We will need the StyleGAN2 repository code to load the .pkl files.
# Instead of cloning, we'll try to use a more direct method if possible,
# or advise the user to run a setup script.


class StyleGAN2AgingEngine:
    """Face-aging engine built on a pretrained StyleGAN2(-ADA) generator.

    Loads the generator plus precomputed W-space boundary vectors (age and,
    optionally, gender), edits latents along those directions, and can invert
    a real photo into the generator's W+ latent space via optimization.
    """

    def __init__(self, stylegan_path, age_vector_path, gender_vector_path, device="cuda"):
        """Store paths and target device; nothing is loaded until load_model().

        Args:
            stylegan_path: path to the StyleGAN2 .pkl weights file.
            age_vector_path: path to the .npy age boundary vector.
            gender_vector_path: path to the .npy gender boundary vector (optional file).
            device: torch device string, e.g. "cuda" or "cpu".
        """
        self.device = torch.device(device)
        self.stylegan_path = stylegan_path
        self.age_vector_path = age_vector_path
        self.gender_vector_path = gender_vector_path
        self.G = None                 # generator network, populated by load_model()
        self.age_direction = None     # W-space age direction tensor
        self.gender_direction = None  # W-space gender direction tensor, or None

    def load_model(self):
        """Load the StyleGAN2 generator and the latent boundary vectors.

        Raises:
            FileNotFoundError: if the StyleGAN2 weights or the age boundary
                vector are missing on disk.
        """
        print(f"Loading StyleGAN2 from {self.stylegan_path}...")
        # StyleGAN2-ADA weights are stored in pickle files with complex object
        # structures. Unpickling requires 'dnnlib' and 'torch_utils' from the
        # StyleGAN2 repo to be importable; that setup is handled separately.
        if not os.path.exists(self.stylegan_path):
            raise FileNotFoundError("StyleGAN2 weights not found. Please run download_weights.py first.")
        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file -- only load weights obtained from a trusted source.
        with open(self.stylegan_path, 'rb') as f:
            self.G = pickle.load(f)['G_ema'].to(self.device)

        print("Loading Age Boundary vector...")
        # Mirror the explicit existence check used for the other artifacts so
        # the user gets an actionable message instead of a bare np.load error.
        if not os.path.exists(self.age_vector_path):
            raise FileNotFoundError(f"Age boundary vector not found: {self.age_vector_path}")
        self.age_direction = np.load(self.age_vector_path)
        self.age_direction = torch.from_numpy(self.age_direction).to(self.device).float()

        print("Loading Gender Boundary vector...")
        # The gender vector is optional; gender_direction stays None if absent.
        if os.path.exists(self.gender_vector_path):
            self.gender_direction = np.load(self.gender_vector_path)
            self.gender_direction = torch.from_numpy(self.gender_direction).to(self.device).float()

    def generate_at_age(self, latent_w, age_coeff, gender_coeff=0.0):
        """
        latent_w: The W space latent (shape: 1, 18, 512)
        age_coeff: Higher = Older
        gender_coeff: Shifting between Male/Female
        Returns an HxWx3 uint8 numpy image.
        """
        w_aged = latent_w.clone()
        # Apply Age manipulation (broadcast the 512-dim direction over all W+ layers)
        w_aged += age_coeff * self.age_direction.view(1, 1, 512)
        # Apply Gender correction if a gender vector was loaded
        if self.gender_direction is not None:
            w_aged += gender_coeff * self.gender_direction.view(1, 1, 512)
        # Generate the image
        with torch.no_grad():
            img = self.G.synthesis(w_aged, noise_mode='const')
            # Generator output is float in roughly [-1, 1], NCHW;
            # convert to uint8 [0, 255], NHWC for display/saving.
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img.cpu().numpy()[0]

    def project_image(self, target_img_pil, num_steps=100):
        """
        Inverts a real image into StyleGAN2 latent space (W+).
        This takes ~15-30 seconds on a GTX 1650.
        Returns a detached (1, num_ws, 512) W+ tensor; falls back to the
        average latent if optimization crashes (e.g. GPU OOM).
        """
        import torch.nn.functional as F

        # Preprocess: PIL image -> float NCHW tensor in [-1, 1], resized to
        # the generator's native resolution.
        target_img = np.array(target_img_pil.convert('RGB'))
        target_img = torch.from_numpy(target_img).permute(2, 0, 1).unsqueeze(0).to(self.device).to(torch.float32)
        target_img = (target_img / 127.5 - 1.0)
        # BUGFIX: previously reached interpolate via sys.modules['torch.nn.functional']
        # even though F was already imported above; use F directly.
        target_img = F.interpolate(target_img, size=(self.G.img_resolution, self.G.img_resolution), mode='area')

        # Clear memory before heavy lifting (only meaningful on CUDA devices)
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

        # Start optimization from the average W, tiled across all W+ layers.
        w_avg = self.G.mapping.w_avg
        w_pivot = w_avg.clone().detach().unsqueeze(0).repeat(1, self.G.mapping.num_ws, 1)
        w_opt = w_pivot.clone().detach().requires_grad_(True)

        optimizer = torch.optim.Adam([w_opt], lr=0.1)

        print("Starting Inversion (Identity Lock)...")
        try:
            for step in range(num_steps):
                optimizer.zero_grad()
                synth_img = self.G.synthesis(w_opt, noise_mode='const')
                # Plain pixel-space MSE against the preprocessed target.
                loss = F.mse_loss(synth_img, target_img)
                loss.backward()
                optimizer.step()
                if step % 2 == 0:
                    print(f"  Step {step}/{num_steps}, Loss: {loss.item():.4f}")
            print("Inversion Complete.")
        except Exception as e:
            # Deliberate best-effort fallback: report the failure and return
            # the average face rather than crashing the caller.
            print(f"Critical Error during inversion: {e}")
            import traceback
            traceback.print_exc()
            return w_pivot.detach()  # Return average face as fallback
        finally:
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

        return w_opt.detach()


# Instructions for the USER:
# To run this, we need the StyleGAN2-ADA repo.
# I will create a setup_path1.py script to handle this automatically.