File size: 4,758 Bytes
2ab0040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import numpy as np
import pickle
import sys
import os

# Loading the StyleGAN2 .pkl checkpoints requires code from the StyleGAN2 repository.
# Rather than cloning it here, we try a more direct method where possible,
# or advise the user to run a setup script.

class StyleGAN2AgingEngine:
    """Age (and optionally gender) editing on top of a StyleGAN2 generator.

    Typical use: construct the engine, call :meth:`load_model` once, invert a
    photo with :meth:`project_image`, then render edited versions of the
    resulting latent with :meth:`generate_at_age`.
    """

    def __init__(self, stylegan_path, age_vector_path, gender_vector_path, device="cuda"):
        """Store configuration only; nothing is read from disk here.

        Args:
            stylegan_path: Path to the pickled StyleGAN2 checkpoint.
            age_vector_path: Path to the ``.npy`` age direction vector.
            gender_vector_path: Path to the ``.npy`` gender direction vector.
                The file is optional; see :meth:`load_model`.
            device: Anything accepted by ``torch.device`` (default ``"cuda"``).
        """
        self.device = torch.device(device)
        self.stylegan_path = stylegan_path
        self.age_vector_path = age_vector_path
        self.gender_vector_path = gender_vector_path
        # Populated by load_model(); None means "not loaded yet".
        self.G = None
        self.age_direction = None
        self.gender_direction = None

    def _load_direction(self, path):
        """Read a ``.npy`` direction vector as a float32 tensor on ``self.device``."""
        return torch.from_numpy(np.load(path)).to(self.device).float()

    def load_model(self):
        """Load the generator and the latent direction vectors from disk.

        The age vector is mandatory. The gender vector is optional: when its
        file is absent a notice is printed and ``gender_direction`` stays
        ``None``, which disables the gender term in :meth:`generate_at_age`.

        Raises:
            FileNotFoundError: If the checkpoint or the age vector is missing.
        """
        print(f"Loading StyleGAN2 from {self.stylegan_path}...")

        if not os.path.exists(self.stylegan_path):
            raise FileNotFoundError("StyleGAN2 weights not found. Please run download_weights.py first.")
        # Check the mandatory age vector up front so a missing file is reported
        # before the expensive generator load rather than after it.
        if not os.path.exists(self.age_vector_path):
            raise FileNotFoundError(f"Age boundary vector not found at {self.age_vector_path}.")

        with open(self.stylegan_path, 'rb') as f:
            # Unpickling needs the 'dnnlib' and 'torch_utils' packages from the
            # StyleGAN2 repo to be importable.
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only load checkpoints obtained from a trusted source.
            self.G = pickle.load(f)['G_ema'].to(self.device)

        print("Loading Age Boundary vector...")
        self.age_direction = self._load_direction(self.age_vector_path)

        print("Loading Gender Boundary vector...")
        if os.path.exists(self.gender_vector_path):
            self.gender_direction = self._load_direction(self.gender_vector_path)
        else:
            # Say so explicitly instead of silently skipping after the
            # "Loading..." message above.
            self.gender_direction = None
            print(f"Gender Boundary vector not found at {self.gender_vector_path}; gender correction disabled.")

    def generate_at_age(self, latent_w, age_coeff, gender_coeff=0.0):
        """Render the face encoded by ``latent_w`` after shifting its age.

        Args:
            latent_w: The W space latent (shape: 1, 18, 512 for the usual
                1024px model; other layer counts / widths are accepted as long
                as the direction vectors match the last dimension).
            age_coeff: Step along the age direction. Higher = Older.
            gender_coeff: Step along the gender direction (shifting between
                Male/Female). Ignored when no gender vector was loaded.

        Returns:
            The first image of the batch as a ``uint8`` NumPy array laid out
            height x width x channels.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called yet.
        """
        if self.G is None or self.age_direction is None:
            raise RuntimeError("Model not loaded. Call load_model() before generate_at_age().")

        # Use the latent's own width rather than a hard-coded 512. The -1 lets
        # a single vector broadcast over every layer, and also accepts a
        # per-layer direction.
        latent_dim = latent_w.shape[-1]

        # Work on a copy so the caller's latent is never mutated.
        w_aged = latent_w.clone()
        w_aged += age_coeff * self.age_direction.reshape(1, -1, latent_dim)

        if self.gender_direction is not None:
            w_aged += gender_coeff * self.gender_direction.reshape(1, -1, latent_dim)

        with torch.no_grad():
            img = self.G.synthesis(w_aged, noise_mode='const')

        # Map generator output from [-1, 1] to [0, 255] and go channels-last.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img.cpu().numpy()[0]

    def project_image(self, target_img_pil, num_steps=100):
        """Invert a real image into StyleGAN2 latent space (W+).

        Starts from the generator's average latent and minimises the pixel MSE
        between the synthesised image and the target with Adam.
        Per the original author, this takes ~15-30 seconds on a GTX 1650.

        Args:
            target_img_pil: Image object exposing ``convert('RGB')`` (a PIL
                image). It is resized to the generator's square resolution.
            num_steps: Number of optimisation steps.

        Returns:
            The optimised latent, detached. If the optimisation loop raises,
            the error is printed and the average latent is returned instead as
            a best-effort fallback.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called yet.
        """
        import torch.nn.functional as F

        if self.G is None:
            raise RuntimeError("Model not loaded. Call load_model() before project_image().")

        # Preprocess: HWC uint8 -> 1xCxHxW float in [-1, 1] at generator size.
        target_img = np.array(target_img_pil.convert('RGB'))
        target_img = torch.from_numpy(target_img).permute(2, 0, 1).unsqueeze(0).to(self.device).to(torch.float32)
        target_img = (target_img / 127.5 - 1.0)
        target_img = F.interpolate(target_img, size=(self.G.img_resolution, self.G.img_resolution), mode='area')

        # Clear memory before heavy lifting
        torch.cuda.empty_cache()

        # Start every layer from the average latent.
        w_avg = self.G.mapping.w_avg
        w_pivot = w_avg.clone().detach().unsqueeze(0).repeat(1, self.G.mapping.num_ws, 1)
        w_opt = w_pivot.clone().detach().requires_grad_(True)

        optimizer = torch.optim.Adam([w_opt], lr=0.1)

        print("Starting Inversion (Identity Lock)...")
        try:
            for step in range(num_steps):
                optimizer.zero_grad()
                synth_img = self.G.synthesis(w_opt, noise_mode='const')
                loss = F.mse_loss(synth_img, target_img)
                loss.backward()
                optimizer.step()
                if step % 2 == 0:
                    print(f"  Step {step}/{num_steps}, Loss: {loss.item():.4f}")
            print("Inversion Complete.")
        except Exception as e:
            # Deliberate best-effort: report the failure and fall back to the
            # average face rather than crashing the caller.
            print(f"Critical Error during inversion: {e}")
            import traceback
            traceback.print_exc()
            return w_pivot.detach()  # Return average face as fallback
        finally:
            torch.cuda.empty_cache()

        return w_opt.detach()

# Instructions for the user:
# Running this module requires the StyleGAN2-ADA repository.
# A setup_path1.py script is planned to handle that setup automatically.