Spaces:
Runtime error
Runtime error
| from editings.styleclip_directions.styleclip_mapper_network import LevelsMapper | |
| import torch | |
| import csv | |
| from options import Settings | |
| import os | |
class Options:
    """Bundle of per-level switches handed to ``LevelsMapper``.

    The three flags are stored unchanged under attributes of the same name.
    NOTE(review): presumably LevelsMapper reads them by attribute name to decide
    which of the coarse/medium/fine sub-mappers to disable — confirm there.
    """

    def __init__(self, no_coarse_mapper, no_medium_mapper, no_fine_mapper) -> None:
        # Keep the flags verbatim; no validation or conversion is applied here.
        (self.no_coarse_mapper,
         self.no_medium_mapper,
         self.no_fine_mapper) = (no_coarse_mapper, no_medium_mapper, no_fine_mapper)
class StyleClip:
    """Apply StyleCLIP edits (latent mappers or a global direction) to latents.

    Mapper checkpoints are built lazily on first use and cached on the
    instance as ``<editname>_mapper`` attributes.
    """

    # File (inside ``Settings.styleclip_global_directions``) and scale used for
    # 'global' edits. Class attributes so they can be overridden without
    # touching edit(); the values reproduce the previous hard-coded behaviour.
    # NOTE(review): the global branch ignores cfg.edit / cfg.strength and always
    # applies this single direction — confirm that is intended.
    global_direction_file = 'makeup.pt'
    global_direction_strength = 10

    def __init__(self, mapping_configs_path=None) -> None:
        """Read the per-edit mapper level flags from the mapping-config CSV.

        Args:
            mapping_configs_path: Path to the CSV file. Defaults to
                ``styleclip_mapping_configs.csv`` inside
                ``Settings.styleclip_settings``.

        Each non-empty row is ``editname, flag, flag, ...``. A flag is True
        only when its whitespace-stripped text is exactly ``"True"``.
        """
        if mapping_configs_path is None:
            mapping_configs_path = os.path.join(
                Settings.styleclip_settings, 'styleclip_mapping_configs.csv')
        self.styleclip_mapping_configs = {}
        # newline='' is what the csv module documentation requires for reading.
        with open(mapping_configs_path, "r", newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Blank lines come back as [] and have no edit name.
                    continue
                key, *flags = row
                # strip() so "smile, True, False" is parsed the same as
                # "smile,True,False".
                self.styleclip_mapping_configs[key] = [
                    flag.strip() == "True" for flag in flags]

    def edit(self, latent, cfg):
        """Return ``latent`` shifted by the StyleCLIP edit described by ``cfg``.

        Args:
            latent: Latent code tensor to edit.
            cfg: Object with a ``type`` attribute, either ``'mapper'`` or
                ``'global'``. Mapper edits also read ``cfg.edit`` (the mapper
                name) and ``cfg.strength`` (the scale of the mapper output).

        Returns:
            The edited latent, computed without autograd tracking.

        Raises:
            ValueError: if ``cfg.type`` is not a supported edit type.
        """
        with torch.no_grad():
            if cfg.type == 'mapper':
                mapper = self.build_mapper(cfg.edit)
                return latent + cfg.strength * mapper(latent)
            if cfg.type == 'global':
                # Load onto the latent's device so the addition cannot fail
                # with a device mismatch (e.g. a CUDA-saved file on a CPU host).
                direction = torch.load(
                    os.path.join(Settings.styleclip_global_directions,
                                 self.global_direction_file),
                    map_location=latent.device)
                return latent + self.global_direction_strength * direction
        raise ValueError(f"Unsupported StyleCLIP edit type: {cfg.type!r}")

    def build_mapper(self, editname):
        """Return the mapper for ``editname``, building and caching it on first use.

        On first use a ``LevelsMapper`` is built from the level flags read in
        ``__init__``. It is then:

        * strictly loaded from ``<editname>.pt`` in
          ``Settings.styleclip_mapper_directions``;
        * moved to ``Settings.device``;
        * frozen (no parameter gradients) and switched to eval mode;
        * stored as ``self.<editname>_mapper``.

        Raises:
            KeyError: if ``editname`` has no row in the mapping-config CSV.
        """
        attr_name = f"{editname}_mapper"
        # A getattr default replaces the former bare ``except:`` cache probe.
        mapper = getattr(self, attr_name, None)
        if mapper is not None:
            return mapper
        opts = Options(*self.styleclip_mapping_configs[editname])
        mapper = LevelsMapper(opts)
        # map_location keeps checkpoints saved on another device (e.g. CUDA)
        # loadable on the configured device.
        ckpt = torch.load(
            os.path.join(Settings.styleclip_mapper_directions, f'{editname}.pt'),
            map_location=Settings.device)
        mapper.load_state_dict(ckpt, strict=True)
        mapper.to(device=Settings.device)
        for param in mapper.parameters():
            param.requires_grad = False
        mapper.eval()
        setattr(self, attr_name, mapper)
        return mapper