Spaces:
Runtime error
Runtime error
| import os | |
| import pickle | |
| import numpy as np | |
| from dnnlib import tflib | |
| import tensorflow as tf | |
| import argparse | |
def LoadModel(dataset_name):
    """Load a pickled StyleGAN snapshot from ``./model/<dataset_name>.pkl``.

    TensorFlow is initialized via ``tflib.init_tf()`` before unpickling. The
    pickle is expected to hold a 3-tuple; only its last element (``Gs``) is
    returned.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file -- only load snapshots obtained from a trusted source.
    """
    tflib.init_tf()  # Initialize TensorFlow.
    pkl_file = os.path.join('./model/', dataset_name + '.pkl')
    with open(pkl_file, 'rb') as handle:
        snapshot = pickle.load(handle)
    _, _, Gs = snapshot
    return Gs
def lerp(a, b, t):
    """Linearly interpolate from ``a`` toward ``b`` by factor ``t``.

    ``t == 0`` yields ``a`` and ``t == 1`` yields ``b``; works element-wise
    (with broadcasting) for numpy arrays.
    """
    delta = b - a
    return a + delta * t
# stylegan-ada
def SelectName(layer_name, suffix):
    """Return True if ``layer_name`` names a style tensor of ``G_synthesis_1``.

    ``layer_name`` is the ``str()`` of an op's output tuple, as built by
    ``GetSNames``.

    Args:
        layer_name: String rendering of a graph op's outputs.
        suffix: ``None`` to match ``add:0`` tensors whose shape starts with an
            unknown batch dimension (``shape=(?,``) -- presumably the
            StyleGAN2-ADA style outputs. Otherwise a tensor-name suffix (e.g.
            ``'/mul_1:0'``, ``'/mod_weight/read:0'``, ``'/MatMul:0'``) that must
            directly follow one of the ``/Conv0_up``, ``/Conv1``, ``4x4/Conv``
            or ``/ToRGB`` layer names.

    Returns:
        bool: whether the op's outputs should be selected.
    """
    # Every accepted tensor must live in the synthesis network copy.
    if 'G_synthesis_1' not in layer_name:
        return False
    if suffix is None:
        return 'add:0' in layer_name and 'shape=(?,' in layer_name
    prefixes = ('/Conv0_up', '/Conv1', '4x4/Conv', '/ToRGB')
    return any((prefix + suffix) in layer_name for prefix in prefixes)
def GetSNames(suffix):
    """Collect style tensors from the current default TensorFlow graph.

    Each op's output tuple is rendered with ``str()`` and filtered through
    ``SelectName(..., suffix)``; for every matching op the first output tensor
    is returned, in graph-operation order.

    Args:
        suffix: Forwarded to ``SelectName`` (``None`` or a tensor-name suffix).

    Returns:
        list of ``tf.Tensor``.
    """
    # Read the default graph directly: a bare tf.Session() is bound to that
    # same graph, so creating (and tearing down) one just to reach
    # ``sess.graph`` is unnecessary work.
    graph = tf.get_default_graph()
    select_layers = []
    for op in graph.get_operations():
        outputs = op.values()
        if SelectName(str(outputs), suffix):
            select_layers.append(outputs[0])
    return select_layers
def SelectName2(layer_name):
    """Return True for ``mod_bias`` / ``mod_weight`` names that are not in a ToRGB layer."""
    if 'ToRGB' in layer_name:
        return False
    return 'mod_bias' in layer_name or 'mod_weight' in layer_name
def GetKName(Gs):
    """Return the synthesis network's modulation variables, excluding ToRGB layers.

    Variables come from ``Gs.components.synthesis.vars`` in its iteration
    order and are kept when ``SelectName2(str(var))`` is true, i.e. their
    string form mentions ``mod_weight`` or ``mod_bias`` but not ``ToRGB``.

    Args:
        Gs: Generator network as returned by ``LoadModel``.

    Returns:
        list of the selected variable objects.
    """
    # Only the values are needed; the variable names are already part of str(var).
    synthesis_vars = Gs.components.synthesis.vars.values()
    return [var for var in synthesis_vars if SelectName2(str(var))]
def GetCode(Gs, random_state, num_img, num_once, dataset_name,
            truncation_psi=0.7, truncation_cutoff=8):
    """Sample ``num_img`` W latents with the mapping network and save them.

    z latents are drawn in batches of at most ``num_once`` rows from
    ``np.random.RandomState(random_state)`` and mapped with
    ``Gs.components.mapping``. When both truncation parameters are set, the
    mapped codes of layers ``< truncation_cutoff`` are pulled toward
    ``dlatent_avg`` by ``truncation_psi``. The layer-0 vector of every sample
    is kept, and the resulting float32 array of shape ``(num_img, width)`` is
    written to ``./npy/<dataset_name>/W.npy``.

    Args:
        Gs: Generator network as returned by ``LoadModel``.
        random_state: Seed for the z sampler.
        num_img: Total number of latents to produce. Need not be a multiple
            of ``num_once``; the final batch is simply smaller.
        num_once: Maximum batch size per mapping-network call.
        dataset_name: Sub-directory of ``./npy`` that receives ``W.npy``.
        truncation_psi: Truncation strength; ``None`` disables truncation.
        truncation_cutoff: Number of leading layers to truncate; ``None``
            disables truncation.
    """
    rnd = np.random.RandomState(random_state)  # 5
    dlatent_avg = Gs.get_var('dlatent_avg')

    batches = []
    for start in range(0, num_img, num_once):
        count = min(num_once, num_img - start)
        src_latents = rnd.randn(count, Gs.input_shape[1])
        src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]

        # Apply truncation trick.
        if truncation_psi is not None and truncation_cutoff is not None:
            layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            src_dlatents = lerp(dlatent_avg, src_dlatents, coefs)

        # Store unconditionally, so disabling truncation still yields data.
        batches.append(src_dlatents[:, 0, :].astype('float32'))

    if batches:
        # The width comes from the mapping output instead of a hard-coded 512.
        dlatents = np.concatenate(batches, axis=0)
    else:
        # num_img <= 0: keep the previous empty (0, 512) result.
        dlatents = np.zeros((0, 512), dtype='float32')
    print('get all z and w')

    # np.save appends the '.npy' extension.
    np.save(os.path.join('./npy', dataset_name, 'W'), dlatents)
def GetImg(Gs, num_img, num_once, dataset_name, save_name='images'):
    """Render images for the first ``num_img`` W latents saved by ``GetCode``.

    Latents are read from ``./npy/<dataset_name>/W.npy``. Each one is tiled
    across ``Gs.components.synthesis.input_shape[1]`` layers and rendered
    individually with ``randomize_noise=False``. Output goes through
    ``tflib.convert_images_to_uint8`` with NCHW -> NHWC conversion. The stacked
    images are saved to ``./npy/<dataset_name>/<save_name>.npy``.

    Args:
        Gs: Generator network as returned by ``LoadModel``.
        num_img: Number of latents to render. Need not be a multiple of
            ``num_once``; raises ``IndexError`` if it exceeds the saved count.
        num_once: Group size; the group index is printed as progress output.
        dataset_name: Sub-directory of ``./npy`` to read from and write to.
        save_name: Output file stem.
    """
    print('Generate Image')
    dlatents = np.load(os.path.join('./npy', dataset_name, 'W.npy'))
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    num_layers = Gs.components.synthesis.input_shape[1]  # loop-invariant

    all_images = []
    for group_idx, start in enumerate(range(0, num_img, num_once)):
        print(group_idx)
        stop = min(start + num_once, num_img)  # last group may be partial
        for idx in range(start, stop):
            w = np.tile(dlatents[idx][None, None, :], (1, num_layers, 1))
            image = Gs.components.synthesis.run(w, randomize_noise=False, output_transform=fmt)
            all_images.append(image)
    all_images = np.concatenate(all_images)

    np.save(os.path.join('./npy', dataset_name, save_name), all_images)
def GetS(dataset_name, num_img, num_once=1_000):
    """Compute style-space (S) activations for the first ``num_img`` saved W latents.

    Latents are read from ``./npy/<dataset_name>/W.npy``. The generator is
    loaded inside a fresh ``tf.Session``, and the tensors picked by
    ``GetSNames(suffix=None)`` are fetched. The feed is
    ``'G_synthesis_1/dlatents_in:0'``, with each W tiled over the synthesis
    layers.

    The fetch runs in chunks of ``num_once`` latents; the results are
    concatenated along axis 0. This avoids materialising the whole tiled
    ``(num_img, num_layers, width)`` feed at once. That is safe because every
    selected tensor has a leading batch dimension (``SelectName`` requires
    ``shape=(?,``).

    Args:
        dataset_name: Model / ``./npy`` sub-directory name.
        num_img: Number of saved latents to use.
        num_once: Chunk size for ``sess.run``; ``None`` or ``0`` fetches
            everything in a single call.

    Returns:
        ``[layer_names, all_s]``: the selected tensor names, and one numpy
        array per tensor (same order) with one row per latent.
    """
    print('Generate S')
    dlatents = np.load(os.path.join('./npy', dataset_name, 'W.npy'))[:num_img]

    with tf.Session() as sess:
        # NOTE(review): this initializer is built before LoadModel adds the
        # generator's variables, so presumably it does not touch them; kept to
        # preserve the existing session setup.
        init = tf.global_variables_initializer()
        sess.run(init)

        # Loaded inside the session so the network presumably binds to `sess`
        # as its default session -- confirm against dnnlib.tflib.
        Gs = LoadModel(dataset_name)
        Gs.print_layers()  # for ada
        select_layers1 = GetSNames(suffix=None)  # None,'/mul_1:0','/mod_weight/read:0','/MatMul:0'
        num_layers = Gs.components.synthesis.input_shape[1]

        step = num_once if num_once else max(len(dlatents), 1)
        # Run at least once so an empty input still yields per-tensor (empty) arrays.
        starts = list(range(0, len(dlatents), step)) or [0]
        chunks = []
        for start in starts:
            batch = dlatents[start:start + step]
            batch = np.tile(batch[:, None, :], (1, num_layers, 1))
            chunks.append(sess.run(
                select_layers1,
                feed_dict={'G_synthesis_1/dlatents_in:0': batch}))

    # Re-assemble: one array per selected tensor, rows in latent order.
    all_s = [np.concatenate(per_tensor, axis=0) for per_tensor in zip(*chunks)]
    layer_names = [layer.name for layer in select_layers1]
    save_tmp = [layer_names, all_s]
    return save_tmp
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.

    Can be used as an output transformation for Network.run().

    Args:
        images: numpy array of images. When ``nchw_to_nhwc`` is true it must
            be 4-D and is transposed from (N, C, H, W) to (N, H, W, C).
        drange: ``(low, high)`` pair mapped linearly onto ``[0, 255]``. A tuple
            default avoids the shared-mutable-default pitfall; any indexable
            pair (e.g. a list) is still accepted.
        nchw_to_nhwc: Whether to move the channel axis last.

    Returns:
        A new ``uint8`` array (the input is not modified). Values are clipped
        to [0, 255] and rounded half-up.
    """
    if nchw_to_nhwc:
        images = np.transpose(images, [0, 2, 3, 1])

    low, high = drange[0], drange[1]
    scale = 255 / (high - low)
    # The +0.5 offset makes the truncating astype() below round to nearest.
    images = images * scale + (0.5 - low * scale)

    np.clip(images, 0, 255, out=images)  # in place on the fresh array from above
    return images.astype('uint8')
def GetCodeMS(dlatents):
    """Return the per-array column means and standard deviations.

    Args:
        dlatents: Sequence of numpy arrays, e.g. the ``all_s`` list that
            ``GetS`` returns; each array is reduced over axis 0.

    Returns:
        ``(m, std)``: two lists holding one ``mean(axis=0)`` and one
        ``std(axis=0)`` (population std) array per input array, in input order.
    """
    m = []
    std = []
    for codes in dlatents:
        m.append(codes.mean(axis=0))
        std.append(codes.std(axis=0))
    return m, std
#%%
if __name__ == "__main__":
    # Script-local dependency: replaces shelling out to wget/mv/mkdir.
    import urllib.request

    parser = argparse.ArgumentParser(
        description='Sample StyleGAN2 latent codes (W) or style-space codes (S) '
                    'and save them under ./npy/<dataset_name>/.')
    parser.add_argument('--dataset_name', type=str, default='ffhq',
                        help='name of dataset, for example, ffhq')
    parser.add_argument('--code_type', choices=['w', 's', 's_mean_std'], default='w')
    args = parser.parse_args()

    random_state = 5
    num_img = 100_000
    num_once = 1_000
    dataset_name = args.dataset_name

    model_dir = './model/'
    model_file = os.path.join(model_dir, dataset_name + '.pkl')
    if not os.path.isfile(model_file):
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
        name = 'stylegan2-' + dataset_name + '-config-f.pkl'
        os.makedirs(model_dir, exist_ok=True)
        download_path = os.path.join(model_dir, name)
        # No shell is involved, so a crafted --dataset_name cannot inject
        # commands. HTTP failures raise instead of being silently ignored.
        # NOTE(review): LoadModel unpickles this file (arbitrary code
        # execution) -- only download from a trusted source.
        urllib.request.urlretrieve(url + name, download_path)
        # Rename only after a complete download, so an interrupted transfer is
        # retried on the next run instead of leaving a truncated model file.
        os.replace(download_path, model_file)

    # makedirs also creates ./npy itself; plain `mkdir` failed when it was missing.
    out_dir = os.path.join('./npy', dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    if args.code_type == 'w':
        Gs = LoadModel(dataset_name=dataset_name)
        GetCode(Gs, random_state, num_img, num_once, dataset_name)
        # GetImg(Gs, num_img=num_img, num_once=num_once, dataset_name=dataset_name, save_name='images_100K')  # no need
    elif args.code_type == 's':
        save_tmp = GetS(dataset_name, num_img=2_000)
        with open(os.path.join(out_dir, 'S'), "wb") as fp:
            pickle.dump(save_tmp, fp)
    elif args.code_type == 's_mean_std':
        save_tmp = GetS(dataset_name, num_img=num_img)
        m, std = GetCodeMS(save_tmp[1])
        with open(os.path.join(out_dir, 'S_mean_std'), "wb") as fp:
            pickle.dump([m, std], fp)