Spaces:
Runtime error
Runtime error
| from manipulate import Manipulator | |
| import tensorflow as tf | |
| import numpy as np | |
| import torch | |
| import clip | |
| from MapTS import GetBoundary,GetDt | |
class StyleCLIP:
    """Wrap a CLIP model and a ``Manipulator`` to edit latents from text.

    The editing direction is derived from two strings, ``self.target`` and
    ``self.neutral``, via ``GetDt``. The manipulation threshold ``self.beta``
    is then chosen by sweeping candidate values through ``GetBoundary``.
    """

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP ViT-B/32 and the data for ``dataset_name``.

        Args:
            dataset_name: Dataset sub-directory name used to locate
                ``./npy/<name>/fs3.npy`` and ``./data/<name>/w_plus.npy``.
        """
        print('load clip')
        # Prefer the GPU when one is visible to torch.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load also returns a preprocessing transform; it is not used here.
        self.model, _ = clip.load("ViT-B/32", device=device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """Build the Manipulator and load the per-dataset arrays.

        Side effects:
        - clears the Keras session;
        - sets NumPy print options to ``suppress=True`` (process-wide);
        - sets ``self.M``, ``self.fs3`` and ``self.M.dlatents``;
        - sets ``self.c_threshold`` (20 for 'ffhq', otherwise 100);
        - finally calls ``SetInitP``.
        """
        tf.keras.backend.clear_session()
        M = Manipulator(dataset_name=dataset_name)
        np.set_printoptions(suppress=True)
        fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')
        self.M = M
        self.fs3 = fs3
        w_plus = np.load(f'./data/{dataset_name}/w_plus.npy')
        # Convert the stored W+ latents with the Manipulator's W2S mapping.
        self.M.dlatents = M.W2S(w_plus)
        # Minimum channel count GetDt2 requires when choosing beta.
        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset the editing parameters to their initial values.

        Sets ``M.alpha`` to ``[3]``, ``M.num_images`` to 1, and empties the
        target/neutral strings. It then recomputes ``dt`` and ``beta`` via
        ``GetDt2`` and selects the first image's latents into
        ``M.dlatent_tmp``.

        NOTE(review): with both strings empty, the direction returned by
        ``GetDt`` may be degenerate (e.g. a zero vector). Confirm how MapTS
        handles this.
        """
        self.M.alpha = [3]  # presumably the manipulation strength(s)
        self.M.num_images = 1
        self.target = ''
        self.neutral = ''
        self.GetDt2()
        img_index = 0
        # A length-1 slice (not a plain index) keeps the leading axis.
        self.M.dlatent_tmp = [
            tmp[img_index:img_index + 1] for tmp in self.M.dlatents
        ]

    @staticmethod
    def _select_beta(betas, num_cs, c_threshold, fallback=0.1):
        """Return the last beta whose channel count exceeds ``c_threshold``.

        For ascending ``betas`` this is the largest qualifying beta. Returns
        ``fallback`` when no beta qualifies.

        Args:
            betas: 1-D sequence of candidate thresholds.
            num_cs: Channel counts, aligned element-wise with ``betas``.
            c_threshold: A count must be strictly greater than this.
            fallback: Value returned when nothing qualifies.
        """
        selected = np.asarray(betas)[np.asarray(num_cs) > c_threshold]
        if selected.size == 0:
            return fallback
        return selected[-1]

    def GetDt2(self):
        """Compute ``self.dt`` from target/neutral and choose ``self.beta``.

        Sweeps beta over ``np.arange(0.1, 0.3, 0.01)``. For each value it
        records the channel count reported by ``GetBoundary`` and prints the
        beta. It then keeps the largest beta whose count is above
        ``self.c_threshold``, or 0.1 when none is.
        """
        classnames = [self.target, self.neutral]
        self.dt = GetDt(classnames, self.model)
        betas = np.arange(0.1, 0.3, 0.01)
        num_cs = []
        for beta in betas:
            # Only the channel count is needed here; the boundary is
            # recomputed for the chosen beta in GetCode.
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)
        self.beta = self._select_beta(betas, num_cs, self.c_threshold)

    def GetCode(self):
        """Return ``M.MSCode`` of the current latents along the boundary at ``self.beta``."""
        boundary_tmp2, _ = GetBoundary(self.fs3, self.dt, self.M, threshold=self.beta)
        return self.M.MSCode(self.M.dlatent_tmp, boundary_tmp2)

    def GetImg(self):
        """Generate from ``GetCode()`` and return element ``[0, 0]`` of the output.

        NOTE(review): this assumes ``GenerateImg`` returns an array indexed
        as (image, alpha-step, ...). TODO confirm.
        """
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        return out[0, 0]
| #%% | |
if __name__ == "__main__":
    # Build the editor for the default dataset. Also bind the instance to
    # `self` at module level, presumably so method bodies can be run
    # cell-by-cell in an interactive session (the file uses `#%%` cell
    # markers). NOTE(review): confirm this alias is still needed.
    self = style_clip = StyleCLIP(dataset_name='ffhq')