Spaces:
Runtime error
Runtime error
| import torch | |
| import csv | |
| from options import Settings | |
| import os | |
class GanSpace:
    """Apply GANSpace PCA-direction edits to latent codes.

    Edit presets are read from ``ganspace_configs.csv`` and the PCA checkpoint
    from ``ffhq_pca.pt``, both located in ``Settings.ganspace_directions``.
    """

    def __init__(self) -> None:
        """Load the named edit configurations.

        Each tab-separated CSV row is ``name<TAB>int<TAB>int...``. The name
        becomes the key in ``self.gan_space_configs`` and the remaining cells
        are parsed as ints (used as ``pca_idx, start, end, strength``).
        """
        self.gan_space_configs = {}
        config_path = os.path.join(Settings.ganspace_directions, "ganspace_configs.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(config_path, "r", newline="", encoding="utf-8") as f:
            for row in csv.reader(f, delimiter="\t"):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing
                    # newline); taking the first cell would raise IndexError.
                    continue
                key, *values = row
                self.gan_space_configs[key] = [int(v) for v in values]

    def edit(self, latent, cfg):
        """Apply the preset ``cfg.edit`` at ``cfg.strength`` to ``latent``.

        Returns the stacked edited latents from :meth:`edit_ganspace`.
        Raises KeyError if ``cfg.edit`` is not a configured preset name.
        """
        with torch.no_grad():
            self.load_ganspace_pca()
            # Copy the stored preset so overriding the strength does not
            # overwrite the value loaded from the CSV for later callers.
            gan_space_config = list(self.gan_space_configs[cfg.edit])
            gan_space_config[-1] = cfg.strength
            return self.edit_ganspace(latent, gan_space_config)

    def load_ganspace_pca(self):
        """Load the PCA checkpoint into ``self.pca`` once; later calls are no-ops."""
        if getattr(self, "pca", None) is not None:
            return
        pca_path = os.path.join(Settings.ganspace_directions, "ffhq_pca.pt")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host, and places the tensors where get_delta will use them.
        self.pca = torch.load(pca_path, map_location=Settings.device)

    def edit_ganspace(self, latents, config):
        """Shift rows ``start:end`` of every latent along one PCA component.

        ``config`` is ``(pca_idx, start, end, strength)``. Each element of
        ``latents`` is edited independently and the results are combined
        with ``torch.stack``.
        """
        pca_idx, start, end, strength = config
        edit_latents = []
        for latent in latents:
            delta = self.get_delta(latent, pca_idx, strength)
            # Allocate directly on the target device (same default dtype as
            # before) instead of creating on CPU and copying.
            delta_padded = torch.zeros(latent.shape, device=Settings.device)
            # In-place add on a view of delta_padded. Repeating by the view's
            # real length (not end - start) keeps shapes consistent when
            # `end` runs past the last row of the latent.
            rows = delta_padded[start:end]
            rows += delta.repeat(rows.shape[0], 1)
            edit_latents.append(latent + delta_padded)
        return torch.stack(edit_latents)

    def get_delta(self, latent, idx, strength):
        """Return the offset along PCA component ``idx`` for ``latent``.

        The first row of the mean-centred latent is projected onto component
        ``idx`` and divided by that component's std, giving ``w_coord``. The
        result is ``(strength - w_coord) * comp[idx] * std[idx]``.
        """
        # pca: ganspace checkpoint. latent: (16, 512) w+
        mean = self.pca["mean"].to(Settings.device)
        component = self.pca["comp"].to(Settings.device)[idx]
        std = self.pca["std"].to(Settings.device)[idx]
        w_centered = latent - mean
        # Only row 0 of the centred latent is used to measure the coordinate.
        w_coord = torch.sum(w_centered[0].reshape(-1) * component.reshape(-1)) / std
        return (strength - w_coord) * component * std