Fix DeltaEditor device handling: add device parameter, error handling for missing files, replace .cuda() with .to(device)
96c63c3
import os
import torch
import pickle
import argparse
import numpy as np

from editings import ganspace
from editings.styleclip.mapper.styleclip_mapper import StyleCLIPMapper
from editings.styleclip.mapper.gloabl_mapper import StyleCLIPGlobalDirection
from editings.deltaedit.editor import DeltaEditor
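
# Note: this table appears to map an S-space layer index that counts conv and
# tRGB layers together (StyleCLIP-style indexing, 26 entries) to the position
# of the matching conv layer in the conv-only stylespace list used by
# get_stylespace_edits; entries at tRGB positions are placeholders.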
STYLESPACE_IDX = [
    0, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8, 8,
    9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 16,
]


class LatentEditor:
    def __init__(self, domain="human_faces", device="cpu"):
        self.domain = domain
        self.device = torch.device(device)
        if self.domain == "human_faces":
            self.interfacegan_directions = {
                "age": "editings/interfacegan_directions/age.pt",
                "smile": "editings/interfacegan_directions/smile.pt",
                "rotation": "editings/interfacegan_directions/rotation.pt",
            }
            self.interfacegan_tensors = {
                name: torch.load(path, map_location=self.device)
                for name, path in self.interfacegan_directions.items()
            }
            self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt", map_location=self.device)
            # GANSpace directions are (PCA component, start layer, end layer, default strength);
            # the strength is replaced by the requested factor in get_ganspace_edits.
            self.ganspace_directions = {
                "eye_openness": (54, 7, 8, 5),
                "trimmed_beard": (58, 7, 9, 7),
                "lipstick": (34, 10, 11, 20),
                "face_roundness": (37, 0, 5, 20.0),
                "nose_length": (51, 4, 5, -30.0),
                "eyebrow_thickness": (37, 8, 9, 20.0),
                "head_angle_up": (11, 1, 4, -10.5),
                "displeased": (36, 4, 7, 10.0),
            }
            # Flags select which mapper levels to disable:
            # (no_coarse_mapper, no_medium_mapper, no_fine_mapper).
            self.styleclip_directions = {
                "afro": [False, False, True],
                "angry": [False, False, True],
                "beyonce": [False, False, False],
                "bobcut": [False, False, True],
                "bowlcut": [False, False, True],
                "curly_hair": [False, False, True],
                "hilary_clinton": [False, False, False],
                "depp": [False, False, False],
                "mohawk": [False, False, True],
                "purple_hair": [False, False, False],
                "surprised": [False, False, True],
                "taylor_swift": [False, False, False],
                "trump": [False, False, False],
                "zuckerberg": [False, False, False],
            }
            self.styleclip_global_editor = self.load_styleclip_global()
            # StyleSpace edits are lists of (layer index, channel index) pairs;
            # layer indices are resolved through STYLESPACE_IDX above.
            self.stylespace_directions = {
                "black hair": [(12, 479)],
                "blond hair": [(12, 479), (12, 266)],
                "grey hair": [(11, 286)],
                "wavy hair": [(6, 500), (8, 128), (5, 92), (6, 394), (6, 323)],
                "bangs": [
                    (3, 259), (6, 285), (5, 414), (6, 128),
                    (9, 295), (6, 322), (6, 487), (6, 504),
                ],
                "receding hairline": [(5, 414), (6, 322), (6, 497), (6, 504)],
                "smiling": [(6, 501)],
                "sslipstick": [(15, 45)],
                "sideburns": [(12, 237)],
                "goatee": [(9, 421)],
                "earrings": [(8, 81)],
                "glasses": [(3, 288), (2, 175), (3, 120), (2, 97)],
                "wear suit": [(9, 441), (8, 292), (11, 358), (6, 223)],
                "gender": [(9, 6)],
            }
            # Precomputed semantic boundaries (hyperplane normals) for W+ edits.
            self.fs_directions = {
                "fs_glasses": "editings/bound/Eyeglasses_boundary.npy",
                "fs_smiling": "editings/bound/Smiling_boundary.npy",
                "fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy",
            }
            self.deltaedit_editor = DeltaEditor(device=self.device)
        elif self.domain == "car":
            self.stylespace_directions = {
                "front": [(8, 411)],
                "headlights": [(8, 441), (9, 355)],
                "grill": [(9, 191)],
                "trees": [(9, 108)],
                "grass_ss": [(12, 107)],
                "sky": [(12, 76)],
                "hubcap": [(12, 113), (12, 439)],
                "car color": [(12, 142), (15, 227)],
                "logo": [(9, 185)],
                "wheel angle": [(8, 420)],
            }
            self.ganspace_pca = torch.load("editings/ganspace_pca/cars_pca.pt", map_location=self.device)
            self.ganspace_directions = {
                "pose_1": (0, 0, 5, 2),
                "pose_2": (0, 0, 5, -2),
                "cube": (16, 3, 6, 25),
                "color": (22, 9, 11, -8),
                "grass": (41, 9, 11, -18),
            }

    def load_styleclip_global(self):
        # Load the precomputed per-channel CLIP-space directions (delta_i_c) and
        # the S-space channel statistics used by StyleCLIP's global-direction editor.
        delta_i_c = torch.from_numpy(
            np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")
        ).float().to(self.device)
        with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
            _, s_std = pickle.load(channels_statistics)
        s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
        with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
            text_prompt_templates = templates.readlines()
        return StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, device=self.device)

    def get_styleclip_mapper_edits(self, start_w, factors, direction):
        latents_to_display = []
        mapper_checkpoint_path = os.path.join(
            "pretrained_models/styleclip_mappers",
            f"{direction}.pt",
        )
        ckpt = torch.load(mapper_checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        styleclip_opts = argparse.Namespace(
            **{
                "mapper_type": "LevelsMapper",
                "no_coarse_mapper": self.styleclip_directions[direction][0],
                "no_medium_mapper": self.styleclip_directions[direction][1],
                "no_fine_mapper": self.styleclip_directions[direction][2],
                "stylegan_size": 1024,
                "checkpoint_path": mapper_checkpoint_path,
            }
        )
        opts.update(vars(styleclip_opts))
        opts = argparse.Namespace(**opts)
        style_clip_net = StyleCLIPMapper(opts)
        style_clip_net.eval()
        style_clip_net.to(self.device)
        # Renamed from `direction` to avoid shadowing the argument of the same name.
        mapped_direction = style_clip_net.mapper(start_w)
        for factor in factors:
            edited_latent = start_w + factor * mapped_direction
            latents_to_display.append(edited_latent)
        return latents_to_display

    def get_styleclip_global_edits(self, start_s, factors, direction):
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)
        directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
        factors = torch.tensor(factors).to(self.device).view(-1, 1)
        start_ss, start_rgb = start_s
        edits_rgb = []
        edits_ss = []
        # Of the 26 S-space entries, every third one starting at index 1 is a tRGB layer.
        for i in range(26):
            if i in [1, 4, 7, 10, 13, 16, 19, 22, 25]:
                edits_rgb.append(directions[i].view(1, -1).repeat(len(factors), 1))
            else:
                edits_ss.append(directions[i].view(1, -1).repeat(len(factors), 1))
        edited_rgb = []
        edited_ss = []
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)
        for orig, edit in zip(start_rgb, edits_rgb):
            edited_rgb.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)
        return edited_ss, edited_rgb

    def get_deltaedit_edits(self, start_s, factors, direction, original_image):
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)
        factors = torch.tensor(factors).to(self.device).view(-1, 1)
        start_ss, start_rgb = start_s
        edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, start_ss)
        # tRGB styles are left unedited; they are only tiled to match the batch of factors.
        edited_rgb = [latent.repeat(len(factors), 1) for latent in start_rgb]
        edited_ss = []
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)))
        return edited_ss, edited_rgb

    def get_ganspace_edits(self, start_w, factors, direction):
        latents_to_display = []
        ganspace_direction = self.ganspace_directions[direction]
        for factor in factors:
            # Swap the default strength (last tuple element) for the requested factor.
            edit_direction = list(ganspace_direction)
            edit_direction[-1] = factor
            new_w = ganspace.edit(start_w, self.ganspace_pca, [tuple(edit_direction)])
            latents_to_display.append(new_w)
        return latents_to_display

    def get_interface_gan_edits(self, start_w, factors, direction):
        tensor_direction = self.interfacegan_tensors[direction]
        latents_to_display = []
        for factor in factors:
            latents_to_display.append(start_w + factor / 2 * tensor_direction)
        return latents_to_display

    def get_stylespace_edits(self, start_s, factors, direction):
        edits = self.stylespace_directions[direction]
        start_stylespaces, start_stylespaces_rgb = start_s
        device = start_stylespaces[0].device
        edited_latent = [s.clone().repeat(len(factors), 1) for s in start_stylespaces]
        factors = torch.tensor(factors).to(device)
        for layer_num, feat_num in edits:
            edited_latent[STYLESPACE_IDX[layer_num]][:, feat_num] += factors * 3
        edited_stylespaces_rgb = [rgb.repeat(len(factors), 1) for rgb in start_stylespaces_rgb]
        return edited_latent, edited_stylespaces_rgb

    def get_fs_edits(self, w, factors, direction):
        path = self.fs_directions[direction]
        boundary = np.load(path)
        device = w.device
        bs = w.size(0)
        # Flatten W+ to (bs, n_layers * 512), shift along the boundary normal, restore the shape.
        w_0 = w.cpu().numpy().reshape(bs, -1)
        boundary = boundary.reshape(1, -1).repeat(bs, 0)
        edits = [torch.tensor(w_0 + factor * boundary).view(bs, -1, 512).to(device) for factor in factors]
        return edits
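

# Usage sketch (assumes the pretrained assets referenced above exist on disk:
# the interfacegan .pt files, the GANSpace PCA, the StyleCLIP global-mapper
# data, and the DeltaEdit checkpoint; replace the zero tensor with a real W+
# latent from an inversion encoder):
if __name__ == "__main__":
    editor = LatentEditor(domain="human_faces", device="cpu")
    w_plus = torch.zeros(1, 18, 512)  # placeholder W+ latent, shape (batch, layers, 512)
    edits = editor.get_interface_gan_edits(w_plus, [-3.0, 0.0, 3.0], "smile")
    print(len(edits), edits[0].shape)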