| |
|
| |
|
| | from manipulate import Manipulator |
| | import tensorflow as tf |
| | import numpy as np |
| | import torch |
| | import clip |
| | from MapTS import GetBoundary,GetDt |
| |
|
class StyleCLIP():
    """Hold the CLIP model, a ``Manipulator`` and the state of one text edit.

    Constructing an instance loads CLIP (ViT-B/32), builds the ``Manipulator``
    for the requested dataset, loads two ``.npy`` files from disk and
    initialises an edit with empty prompts on the first stored latent.

    Attributes set during construction:
        model: the model object returned by ``clip.load``.
        M: the ``Manipulator`` built for ``dataset_name``.
        fs3: array loaded from ``./npy/<dataset_name>/fs3.npy``; it is only
            ever forwarded to ``GetBoundary``.
        c_threshold: 20 for ``'ffhq'``, 100 for any other dataset; compared
            against the count returned by ``GetBoundary``.
        target, neutral: prompt strings handed to ``GetDt`` (both start empty).
        dt: whatever ``GetDt`` returned for the current prompts.
        beta: threshold value passed to ``GetBoundary`` when generating.
    """

    def __init__(self,dataset_name='ffhq'):
        """Load CLIP and then all data for ``dataset_name``.

        Args:
            dataset_name: name of the sub-directory under ``./npy`` and
                ``./data`` to read from; also forwarded to ``Manipulator``.
        """
        print('load clip')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load returns (model, preprocess); the image preprocessing
        # transform is not needed by this class, so it is discarded.
        self.model, _ = clip.load("ViT-B/32", device=device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """Build the manipulator and load the per-dataset arrays.

        Reads ``./npy/<dataset_name>/fs3.npy`` and
        ``./data/<dataset_name>/w_plus.npy`` relative to the current working
        directory, then calls ``SetInitP`` to initialise the edit state.

        Args:
            dataset_name: dataset identifier used in both file paths.
        """
        tf.keras.backend.clear_session()
        manipulator = Manipulator(dataset_name=dataset_name)
        # NOTE: this changes NumPy's print options for the whole process.
        np.set_printoptions(suppress=True)
        fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')

        self.M = manipulator
        self.fs3 = fs3

        w_plus = np.load(f'./data/{dataset_name}/w_plus.npy')
        # presumably converts W+ latents into per-layer style codes --
        # confirm against Manipulator.W2S.
        self.M.dlatents = manipulator.W2S(w_plus)

        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset the edit to its defaults and select the first stored latent."""
        self.M.alpha = [3]
        self.M.num_images = 1

        self.target = ''
        self.neutral = ''
        self.GetDt2()

        img_index = 0
        # A length-1 slice (rather than plain indexing) keeps the leading
        # axis of every entry in ``dlatents``.
        self.M.dlatent_tmp = [
            layer[img_index:img_index + 1] for layer in self.M.dlatents
        ]

    def GetDt2(self):
        """Recompute ``self.dt`` from the prompts and re-select ``self.beta``.

        ``GetBoundary`` is evaluated for every threshold in
        ``np.arange(0.1, 0.3, 0.01)``. ``self.beta`` becomes the largest of
        those thresholds whose returned count exceeds ``self.c_threshold``,
        or ``0.1`` when none does. Each tried threshold is printed.
        """
        self.dt = GetDt([self.target, self.neutral], self.model)

        betas = np.arange(0.1, 0.3, 0.01)
        num_cs = []
        for beta in betas:
            # Only the count is needed here; the boundary itself is rebuilt
            # in GetCode once the final beta is known.
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)

        select = np.array(num_cs) > self.c_threshold
        # betas is ascending, so the last selected entry is the largest one.
        self.beta = betas[select][-1] if select.any() else 0.1

    def GetCode(self):
        """Return the codes for the current latent, ``dt`` and ``beta``.

        Returns:
            The result of ``Manipulator.MSCode`` applied to
            ``self.M.dlatent_tmp`` and the boundary from ``GetBoundary``.
        """
        boundary, _ = GetBoundary(self.fs3, self.dt, self.M, threshold=self.beta)
        return self.M.MSCode(self.M.dlatent_tmp, boundary)

    def GetImg(self):
        """Generate the edited output and return its ``[0, 0]`` element.

        Returns:
            ``out[0, 0]`` where ``out`` is what ``Manipulator.GenerateImg``
            returns for the current codes.
        """
        out = self.M.GenerateImg(self.GetCode())
        return out[0, 0]
| | |
| |
|
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Build the default ('ffhq') instance when the file is run as a script.
    # ``self`` is bound to the same object at module level -- presumably so
    # method bodies can be pasted into an interactive session; confirm before
    # removing.
    style_clip = self = StyleCLIP()
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|