| |
|
| |
|
| |
|
| | import os |
| | import pickle |
| | import numpy as np |
| | from dnnlib import tflib |
| | import tensorflow as tf |
| |
|
| | import argparse |
| |
|
def LoadModel(dataset_name):
    """Load and return the ``Gs`` network stored in ./model/<dataset_name>.pkl.

    The pickle is expected to hold a 3-tuple; only its last element (``Gs``)
    is returned and the first two are discarded.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only load
    checkpoints from a trusted source. The ``__main__`` block fetches them
    from NVIDIA's CDN.
    """
    tflib.init_tf()
    pkl_path = os.path.join('./model/', dataset_name + '.pkl')
    with open(pkl_path, 'rb') as fh:
        _first, _second, Gs = pickle.load(fh)
    return Gs
| |
|
def lerp(a, b, t):
    """Linearly interpolate from ``a`` toward ``b`` by fraction ``t``.

    Evaluated as ``a + (b - a) * t``, so NumPy arrays broadcast element-wise.
    """
    delta = b - a
    return a + delta * t
| |
|
| | |
def SelectName(layer_name, suffix):
    """Decide whether a stringified graph-op output should be collected.

    Only names from the 'G_synthesis_1' scope are ever selected.

    Args:
        layer_name: string form of an op's outputs (GetSNames passes
            ``str(op.values())``).
        suffix: ``None`` to select outputs whose name contains 'add:0' and
            whose printed shape starts with '(?,' (a leading unknown batch
            dimension). Otherwise, a tensor-name suffix that must directly
            follow one of the '/Conv0_up', '/Conv1', '4x4/Conv' or '/ToRGB'
            scopes.

    Returns:
        bool: True if the name matches the selection rule.
    """
    if 'G_synthesis_1' not in layer_name:
        return False
    if suffix is None:
        return 'add:0' in layer_name and 'shape=(?,' in layer_name
    scopes = ('/Conv0_up', '/Conv1', '4x4/Conv', '/ToRGB')
    return any(scope + suffix in layer_name for scope in scopes)
| |
|
| |
|
def GetSNames(suffix):
    """Collect tensors from the default TF graph that SelectName accepts.

    Args:
        suffix: forwarded unchanged to ``SelectName`` (``None`` or a
            tensor-name suffix).

    Returns:
        list of tf.Tensor: the first output of every matching op, in the
        order returned by ``Graph.get_operations()``.
    """
    # The original opened a throwaway tf.Session only to read ``sess.graph``.
    # A Session created without a graph argument uses the default graph, so
    # reading the default graph directly is equivalent and allocates nothing.
    graph = tf.get_default_graph()
    select_layers = []
    for op in graph.get_operations():
        outputs = op.values()
        # Ops with no outputs stringify to '()' and never match, so
        # outputs[0] is only reached for ops that have at least one output.
        if SelectName(str(outputs), suffix):
            select_layers.append(outputs[0])
    return select_layers
| |
|
def SelectName2(layer_name):
    """Return True for 'mod_bias' / 'mod_weight' names outside any 'ToRGB' scope.

    Used by GetKName to filter synthesis-network variables.
    """
    if 'ToRGB' in layer_name:
        return False
    return 'mod_bias' in layer_name or 'mod_weight' in layer_name
| |
|
def GetKName(Gs):
    """Return the synthesis-network variables whose string form SelectName2 accepts.

    Args:
        Gs: network object exposing ``Gs.components.synthesis.vars``
            (a mapping of name -> variable).

    Returns:
        list: matching variables, in the mapping's iteration order.
    """
    # Only the variables are needed. The original iterated .items() and
    # discarded the keys.
    return [
        var
        for var in Gs.components.synthesis.vars.values()
        if SelectName2(str(var))
    ]
| |
|
def GetCode(Gs, random_state, num_img, num_once, dataset_name,
            truncation_psi=0.7, truncation_cutoff=8):
    """Sample ``num_img`` truncated W codes and save them to ./npy/<dataset_name>/W.npy.

    Latents z ~ N(0, I) are drawn from ``np.random.RandomState(random_state)``
    in batches of ``num_once`` and mapped with ``Gs.components.mapping``. The
    truncation trick toward ``Gs.get_var('dlatent_avg')`` is then applied,
    and the first entry along axis 1 of each mapped code is stored.

    Args:
        Gs: generator network (mapping network + ``dlatent_avg`` variable).
        random_state: seed for the latent sampler.
        num_img: total number of codes to produce.
        num_once: batch size for the mapping network (must be > 0). A final
            partial batch is generated when ``num_img`` is not a multiple
            of it.
        dataset_name: sub-directory of ./npy/ the result is written to.
        truncation_psi: interpolation weight toward ``dlatent_avg`` for
            layers below ``truncation_cutoff``. ``None`` disables truncation.
        truncation_cutoff: number of leading layers that are truncated.
            ``None`` disables truncation.

    Raises:
        ValueError: if ``num_once`` is not positive.
    """
    if num_once <= 0:
        raise ValueError('num_once must be positive, got %r' % (num_once,))

    rnd = np.random.RandomState(random_state)
    dlatent_avg = Gs.get_var('dlatent_avg')

    dlatents = np.zeros((num_img, 512), dtype='float32')
    # Step by num_once and clip the last batch. The original floored
    # num_img/num_once, which left trailing rows as zeros.
    for start in range(0, num_img, num_once):
        batch = min(num_once, num_img - start)
        src_latents = rnd.randn(batch, Gs.input_shape[1])
        src_dlatents = Gs.components.mapping.run(src_latents, None)

        if truncation_psi is not None and truncation_cutoff is not None:
            # Axis 1 is treated as the layer axis: layers below the cutoff
            # are pulled toward the average code by truncation_psi.
            layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            src_dlatents = lerp(dlatent_avg, src_dlatents, coefs)

        # Store outside the truncation branch so that disabling truncation
        # no longer leaves the output all zeros. Only index 0 of axis 1 is
        # kept (presumably every layer holds the same w here; TODO confirm).
        dlatents[start:start + batch, :] = src_dlatents[:, 0, :].astype('float32')
    print('get all z and w')

    # np.save appends the '.npy' extension.
    np.save(os.path.join('./npy', dataset_name, 'W'), dlatents)
| |
|
| | |
def GetImg(Gs, num_img, num_once, dataset_name, save_name='images'):
    """Render images for the first ``num_img`` saved W codes and save them.

    Reads ./npy/<dataset_name>/W.npy. Each code is repeated along the
    synthesis network's ``input_shape[1]`` axis and run through
    ``Gs.components.synthesis`` one image at a time (noise not randomized,
    uint8 NHWC output). The stacked result is written to
    ./npy/<dataset_name>/<save_name>.npy.

    Args:
        Gs: generator network.
        num_img: number of codes to render (1 .. number of stored codes).
        num_once: progress-reporting batch size (must be > 0). A final
            partial batch is rendered when ``num_img`` is not a multiple
            of it.
        dataset_name: sub-directory of ./npy/ to read from and write to.
        save_name: output file stem.

    Raises:
        ValueError: if ``num_once`` or ``num_img`` is out of range.
    """
    print('Generate Image')
    dlatents = np.load(os.path.join('./npy', dataset_name, 'W.npy'))
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

    if num_once <= 0:
        raise ValueError('num_once must be positive, got %r' % (num_once,))
    if not 0 < num_img <= len(dlatents):
        raise ValueError('num_img must be in 1..%d, got %r' % (len(dlatents), num_img))

    # Loop-invariant: number of W inputs the synthesis network expects.
    num_ws = Gs.components.synthesis.input_shape[1]

    all_images = []
    # The original floored num_img/num_once and silently skipped the
    # remainder; clip the last batch instead.
    for i, start in enumerate(range(0, num_img, num_once)):
        print(i)
        stop = min(start + num_once, num_img)
        images = []
        for w in dlatents[start:stop]:
            w_plus = np.tile(w[None, None, :], (1, num_ws, 1))
            images.append(Gs.components.synthesis.run(
                w_plus, randomize_noise=False, output_transform=fmt))
        all_images.append(np.concatenate(images))

    all_images = np.concatenate(all_images)
    np.save(os.path.join('./npy', dataset_name, save_name), all_images)
| |
|
def GetS(dataset_name, num_img):
    """Evaluate the selected style tensors for the first ``num_img`` saved W codes.

    Loads ./npy/<dataset_name>/W.npy and builds the model inside a fresh
    ``tf.Session``. The tensors returned by ``GetSNames(suffix=None)`` are
    then fetched, feeding each W code tiled across the synthesis network's
    ``input_shape[1]`` axis into 'G_synthesis_1/dlatents_in:0'.

    Returns:
        list: ``[layer_names, all_s]``, where ``layer_names`` holds the
        fetched tensor names and ``all_s`` is the list of values returned
        by ``sess.run`` in the same order.
    """
    print('Generate S')
    w_path = './npy/' + dataset_name + '/W.npy'
    w_codes = np.load(w_path)[:num_img]

    with tf.Session() as sess:
        # NOTE(review): this runs before LoadModel builds the networks, so it
        # presumably initializes nothing they need. Confirm before removing.
        sess.run(tf.global_variables_initializer())

        Gs = LoadModel(dataset_name)
        Gs.print_layers()
        style_tensors = GetSNames(suffix=None)
        num_ws = Gs.components.synthesis.input_shape[1]
        w_plus = np.tile(w_codes[:, None, :], (1, num_ws, 1))

        all_s = sess.run(
            style_tensors,
            feed_dict={'G_synthesis_1/dlatents_in:0': w_plus},
        )

    names = [tensor.name for tensor in style_tensors]
    return [names, all_s]
| |
|
| | |
| |
|
| |
|
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False):
    """Convert a minibatch of images from float to uint8 with configurable dynamic range.

    Can be used as an output transformation for Network.run().

    Args:
        images: 4-D array-like of pixel values. If ``nchw_to_nhwc`` is set,
            axes are permuted with ``[0, 2, 3, 1]``.
        drange: ``(low, high)`` input range mapped linearly onto [0, 255].
            It is a tuple so the default is never shared mutable state;
            lists are still accepted.
        nchw_to_nhwc: move axis 1 to the last position before converting.

    Returns:
        numpy.ndarray of dtype uint8.
    """
    images = np.asarray(images)
    if nchw_to_nhwc:
        images = np.transpose(images, [0, 2, 3, 1])

    scale = 255 / (drange[1] - drange[0])
    # The +0.5 offset, combined with the truncating uint8 cast below, rounds
    # each clipped value half-up to the nearest integer.
    images = images * scale + (0.5 - drange[0] * scale)

    # ``images`` is now a fresh floating array, so clipping in place is safe.
    np.clip(images, 0, 255, out=images)
    return images.astype('uint8')
| |
|
| |
|
def GetCodeMS(dlatents):
    """Return per-entry mean and standard deviation along axis 0.

    Args:
        dlatents: sequence of arrays. The ``__main__`` block passes the
            ``all_s`` list from GetS, i.e. one array per fetched tensor
            (presumably shaped (num_img, channels); TODO confirm).

    Returns:
        tuple ``(m, std)`` of lists, where ``m[i] = dlatents[i].mean(axis=0)``
        and ``std[i] = dlatents[i].std(axis=0)``.
    """
    m = [codes.mean(axis=0) for codes in dlatents]
    std = [codes.std(axis=0) for codes in dlatents]
    return m, std
| |
|
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Local import: only the CLI entry point needs to launch processes.
    import subprocess

    parser = argparse.ArgumentParser(
        description='Generate W / S latent codes from a pretrained StyleGAN2 model.')

    parser.add_argument('--dataset_name', type=str, default='ffhq',
                        help='name of dataset, for example, ffhq')
    parser.add_argument('--code_type', choices=['w', 's', 's_mean_std'], default='w',
                        help="'w' samples W codes; 's' and 's_mean_std' need an "
                             "existing W.npy from a previous 'w' run")

    args = parser.parse_args()
    random_state = 5
    num_img = 100_000
    num_once = 1_000
    dataset_name = args.dataset_name

    model_file = './model/' + dataset_name + '.pkl'
    if not os.path.isfile(model_file):
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
        name = 'stylegan2-' + dataset_name + '-config-f.pkl'
        # Argument lists (no shell) keep dataset_name from injecting shell
        # syntax. check=True stops here on a failed download instead of
        # failing later inside LoadModel.
        subprocess.run(['wget', url + name, '-P', './model/'], check=True)
        os.replace('./model/' + name, model_file)

    npy_dir = './npy/' + dataset_name
    # makedirs also creates ./npy itself. The old `mkdir` failed if it was missing.
    os.makedirs(npy_dir, exist_ok=True)

    if args.code_type != 'w' and not os.path.isfile(npy_dir + '/W.npy'):
        parser.error(npy_dir + '/W.npy not found; run with --code_type w first')

    if args.code_type == 'w':
        Gs = LoadModel(dataset_name=dataset_name)
        GetCode(Gs, random_state, num_img, num_once, dataset_name)

    elif args.code_type == 's':
        save_name = 'S'
        save_tmp = GetS(dataset_name, num_img=2_000)
        with open(npy_dir + '/' + save_name, "wb") as fp:
            pickle.dump(save_tmp, fp)

    elif args.code_type == 's_mean_std':
        save_tmp = GetS(dataset_name, num_img=num_img)
        m, std = GetCodeMS(save_tmp[1])
        save_name = 'S_mean_std'
        with open(npy_dir + '/' + save_name, "wb") as fp:
            pickle.dump([m, std], fp)
| | |
| | |
| | |
| | |
| | |
| |
|