Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import pandas as pd | |
| import torch.nn.functional as F | |
| from .util import load_img | |
| from configs import path_csv_ffhq_attritube | |
class GenePoolFactory(object):
    """Build and cache per-attribute pools of ``(mu, var)`` pairs from FFHQ images.

    For a given ``(age, gender, race)`` combination, up to ``max_sample``
    matching images are looked up in the FFHQ attribute CSV, encoded, passed
    through ``w2sub34``, and the resulting ``(mu, var)`` pair for each image is
    stored on the CPU. Pools are cached by attribute key, so each combination
    is encoded at most once per factory instance.
    """

    def __init__(self, root_ffhq, device, mean_latent, max_sample=100):
        """
        Args:
            root_ffhq: Directory holding the FFHQ images, named as the
                zero-padded 5-digit ``file_id`` followed by ``.png``. When
                ``None``, ``__call__`` returns an empty (uncached) pool.
            device: Device each loaded image is moved to before encoding.
            mean_latent: Added to the encoder output before ``w2sub34``.
            max_sample: Maximum number of images encoded per pool
                (``None`` means no limit).
        """
        self.device = device
        self.mean_latent = mean_latent
        self.root_ffhq = root_ffhq
        self.max_sample = max_sample
        self.pools = {}  # "{age}-{gender}-{race}" -> list of (mu, var) pairs
        # The CSV must provide the ``file_id``, ``gender``, ``age`` and
        # ``race`` columns used by ``__call__``.
        self.df = pd.read_csv(path_csv_ffhq_attritube)
        # Lower-case exact 'Male' / 'Female' cell values (in any column) so
        # gender lookups can use lowercase labels.
        self.df.replace('Male', 'male', inplace=True)
        self.df.replace('Female', 'female', inplace=True)

    def __call__(self, encoder, w2sub34, age, gender, race):
        """Return the pool for ``(age, gender, race)``, building and caching it on first use.

        Args:
            encoder: Callable applied to each image after it is resized to
                256x256; ``mean_latent`` is added to its output.
            w2sub34: Callable whose result unpacks into three values; the
                first two are kept as ``(mu, var)``.
            age, gender, race: Attribute values matched against the CSV
                columns of the same name, compared as ``str(value)``.

        Returns:
            A list of detached CPU ``(mu, var)`` tensor pairs, one per encoded
            image. The list is empty when no row matches, or when
            ``root_ffhq`` is ``None`` (that case is not cached).
        """
        keyname = f'{age}-{gender}-{race}'
        if keyname in self.pools:
            return self.pools[keyname]
        if self.root_ffhq is None:
            return []

        # Boolean masks instead of a string-built ``DataFrame.query``: attribute
        # values containing quotes can no longer break or inject into the
        # expression. Values are compared as strings, matching quoted-literal
        # semantics.
        df = self.df
        mask = (
            (df['gender'] == str(gender))
            & (df['age'] == str(age))
            & (df['race'] == str(race))
        )
        # Shuffle a flat Python list. ``random.shuffle`` on a 2-D NumPy array
        # swaps row *views*, which duplicates some rows and drops others.
        file_ids = df.loc[mask, 'file_id'].tolist()
        random.shuffle(file_ids)

        pool = []
        for fid in file_ids[:self.max_sample]:
            filename = f'{int(fid):05d}.png'
            img = load_img(os.path.join(self.root_ffhq, filename)).to(self.device)
            w18_1 = encoder(F.interpolate(img, size=(256, 256))) + self.mean_latent
            mu, var, _ = w2sub34(w18_1)
            # Detach so cached tensors do not keep the autograd graph (and the
            # activations it references) alive for the lifetime of the cache.
            pool.append((mu.detach().cpu(), var.detach().cpu()))
        self.pools[keyname] = pool
        return pool