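# Spaces: Running
# (The line above appears to be page-status residue from the site this file
# was copied from; it is not part of the script.)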
"""Batch-decode trained diffusion checkpoints and score the generated samples.

Usage:
    python <this script> RUN_GLOB [TOP_P [CKPT_PREFIX]]

    RUN_GLOB     glob pattern selecting run directories, e.g.
                 'diff_models_synth128*' (quote it so the shell does not
                 expand it).
    TOP_P        value forwarded to the sampling script as ``--top_p``
                 (default: -1.0).
    CKPT_PREFIX  filename prefix of the ``*pt`` checkpoints to consider
                 (default: 'model').

For every matching run directory the script picks the last checkpoint in
sorted order, runs ``scripts/{image,text}_sample.py`` unless the expected
sample file already exists under ``generation_outputs``, and then launches
the evaluation that matches the modality named in the run directory.
"""
import glob
import os
import sys

# Substrings searched for in the run name, in priority order; the first hit
# wins.  'e2e-tgt' has to be tested before 'e2e', which is a substring of it.
_MODALITY_MARKERS = (
    'synth', 'pos', 'image', 'roc', 'e2e-tgt', 'simple-wiki',
    'book', 'yelp', 'commonGen', 'e2e',
)

# Zero-based positions of the underscore-separated run-name fields that are
# parsed below.
_DIM_FIELD = 3
_ARCH_FIELD = 4
_STEPS_FIELD = 6
_SCHEDULE_FIELD = 7

# Value passed as --model_name_or_path to scripts/ppl_under_ar.py, per modality.
_AR_MODEL_PATHS = {
    'pos': 'predictability/diff_models/pos_e=15_b=20_m=gpt2_wikitext-103-raw-v1_s=102',
    'e2e-tgt': "predictability/diff_models/e2e-tgt_e=15_b=20_m=gpt2_wikitext-103-raw-v1_101_None",
    'roc': "predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_v1",
}
_SYNTH128_AR_MODEL_PATH = 'predictability/diff_models/synth_e=15_b=10_m=gpt2_wikitext-103-raw-v1_None'
_SYNTH_AR_MODEL_PATH = 'predictability/diff_models/synth_e=15_b=20_m=gpt2_wikitext-103-raw-v1_None'


def _field(fields, index, default=None):
    """Return ``fields[index]``, or *default* when the run name is too short."""
    return fields[index] if index < len(fields) else default


def _detect_modality(run_name):
    """Return the first modality marker contained in *run_name*, else None."""
    return next((m for m in _MODALITY_MARKERS if m in run_name), None)


def _synth_size(run_name):
    """Return 32 or 128 for 'synth32' / 'synth128' run names, else None."""
    if 'synth32' in run_name:
        return 32
    if 'synth128' in run_name:
        return 128
    return None


def _sample_run(run_dir, top_p, pattern_):
    """Make sure samples exist for the newest checkpoint in *run_dir*.

    Returns ``(tgt, run_name, modality, out_path)`` where *tgt* is the
    checkpoint path and *out_path* the sample file, or ``None`` when the
    directory holds no matching checkpoint or sampling produced no file.
    """
    print(run_dir)
    checkpoints = sorted(glob.glob(f"{run_dir}/{pattern_}*pt"))
    if not checkpoints:
        return None
    tgt = checkpoints[-1]
    # normpath drops a trailing separator so the basename is never empty.
    run_name = os.path.basename(os.path.normpath(run_dir))
    print(run_name)

    fields = run_name.split('_')
    model_arch_ = _field(fields, _ARCH_FIELD)
    model_arch = 'conv-unet' if 'conv-unet' in run_name else 'transformer'
    mode = 'image' if 'conv' in model_arch else 'text'
    print(mode, model_arch_)
    dim_ = _field(fields, _DIM_FIELD)
    modality = _detect_modality(run_name)

    # The values parsed below are only logged; none of them is passed on to
    # the sampling command.
    try:
        diffusion_steps = int(fields[_STEPS_FIELD])
    except (IndexError, ValueError):
        diffusion_steps = 4000
    else:
        print(diffusion_steps)

    noise_schedule = _field(fields, _SCHEDULE_FIELD)
    if noise_schedule in ('cosine', 'linear'):
        print(noise_schedule)
    else:
        noise_schedule = 'cosine'

    try:
        dim = int(dim_.split('rand')[1])
    except (AttributeError, IndexError, ValueError):
        dim = dim_

    print(len(fields))
    try:
        num_channels = int(fields[-1].split('h')[1])
    except (IndexError, ValueError):
        num_channels = 128
    print(tgt, model_arch, dim, num_channels)

    out_dir = 'generation_outputs'
    num_samples = 547 if modality == 'e2e' else 50

    # NOTE: paths are interpolated into the shell command unquoted, so run
    # names containing spaces or shell metacharacters will break it.
    COMMAND = f'python scripts/{mode}_sample.py ' \
              f'--model_path {tgt} --batch_size 50 --num_samples {num_samples} --top_p {top_p} ' \
              f'--out_dir {out_dir} '
    print(COMMAND)

    ckpt_dir, ckpt_file = os.path.split(tgt)
    model_base_name = os.path.basename(ckpt_dir) + f'.{ckpt_file}'
    extension = 'json' if modality in ('e2e-tgt', 'e2e') else 'txt'
    out_path2 = os.path.join(out_dir, f"{model_base_name}.samples_{top_p}.{extension}")

    output_cands = glob.glob(out_path2)
    print(out_path2, output_cands)
    if not output_cands:
        # Nothing cached for this checkpoint/top_p pair: generate it now.
        os.system(COMMAND)
        output_cands = glob.glob(out_path2)
        if not output_cands:
            # Sampling failed or wrote elsewhere; skip this run instead of
            # aborting the whole batch.
            print(f'sampling produced no output at {out_path2}; skipping.')
            return None
    return tgt, run_name, modality, output_cands[0]


def _evaluate_samples(tgt, run_name, modality, out_path2):
    """Launch the evaluation command(s) matching *modality* for *out_path2*."""
    if modality == 'e2e':
        COMMAND1 = f"python diffusion_lm/e2e_data/mbr.py {out_path2}"
        os.system(COMMAND1)
        COMMAND2 = f"python e2e-metrics/measure_scores.py " \
                   f"diffusion_lm/improved_diffusion/out_gen_v2_dropout2/1_valid_gold " \
                   f"{out_path2}.clean -p -t -H > {os.path.join(os.path.split(tgt)[0], 'e2e_valid_eval.txt')}"
        print(COMMAND2)
        os.system(COMMAND2)
        return

    if modality == 'synth':
        if _synth_size(run_name) == 128:
            model_name_path = _SYNTH128_AR_MODEL_PATH
        else:
            model_name_path = _SYNTH_AR_MODEL_PATH
    elif modality in _AR_MODEL_PATHS:
        model_name_path = _AR_MODEL_PATHS[modality]
    else:
        print('not trained a AR model yet... only look at the output plz.')
        return

    COMMAND = f"python scripts/ppl_under_ar.py " \
              f"--model_path {tgt} " \
              f"--modality {modality} --experiment random " \
              f"--model_name_or_path {model_name_path} " \
              f"--input_text {out_path2} --mode eval"
    print(COMMAND)
    print()
    os.system(COMMAND)


def main(argv=None):
    """Decode and evaluate every run directory matched by ``argv[1]``.

    Args:
        argv: argument vector laid out like ``sys.argv`` (program name
            first); defaults to ``sys.argv``.

    Returns:
        The list of sample-file paths that were found or generated.

    Raises:
        SystemExit: if the run-directory glob pattern is missing.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        prog = os.path.basename(argv[0]) if argv else '<script>'
        sys.exit(f'usage: {prog} RUN_GLOB [TOP_P [CKPT_PREFIX]]')

    full_lst = glob.glob(argv[1])
    # argv[2] / argv[3] are optional, so guard on the length that actually
    # contains them before indexing.
    top_p = argv[2] if len(argv) > 2 else -1.0
    print(f'top_p = {top_p}')
    pattern_ = argv[3] if len(argv) > 3 else 'model'
    # Echo the raw third argument as well, but only when it was supplied.
    print(f'pattern_ = {pattern_}', *argv[3:4])

    output_lst = []
    for run_dir in full_lst:
        sampled = _sample_run(run_dir, top_p, pattern_)
        if sampled is None:
            continue
        tgt, run_name, modality, out_path2 = sampled
        output_lst.append(out_path2)
        _evaluate_samples(tgt, run_name, modality, out_path2)

    print('output lists:')
    print("\n".join(output_lst))
    return output_lst


if __name__ == '__main__':
    main()