import os
import shutil
import sys
from collections import defaultdict

import torch
def parse_proxy(fname, scale):
    """Collect per-layer proxy errors from a quantization log file.

    Each line containing ``'proxy error'`` is trimmed to start at its first
    ``'layer'`` token. The two whitespace-separated tokens after ``layer``
    form the layer name. The text after the first ``':'`` is parsed as the
    error value.

    Args:
        fname: Path of the log file to read.
        scale: Key under which each layer's error is stored (this script
            passes the scale suffix string, e.g. ``'100'``).

    Returns:
        dict mapping layer name -> ``{scale: proxy_error}``. If the same
        layer appears on several lines, the last one wins.

    Raises:
        ValueError: if the text after ``':'`` on a matching line is not a
            valid float.
    """
    layer_dict = {}
    # `with` guarantees the file is closed even if float() raises mid-file.
    with open(fname, 'r') as f:
        for line in f:
            if 'proxy error' not in line:
                continue
            line = line.rstrip()
            start = line.find('layer')
            if start == -1:
                # Without a 'layer' tag, slicing from -1 would keep only the
                # last character and record a bogus '' layer; skip instead.
                continue
            line = line[start:]
            proxy_error = float(line[line.find(':') + 1:])
            layer_name = ' '.join(line.split(' ')[1:3])
            layer_dict[layer_name] = {scale: proxy_error}
    return layer_dict
# Merge the per-scale logs into {layer_name: {scale: proxy_error}}.
files = ['075', '080', '085', '090', '095', '100', '103', '105']
total = {}
for scale_key in files:
    res = parse_proxy(f'/work/albert/two_bit_quant/slurm_out/e8p_s{scale_key}.log', scale_key)
    for layer_name, errors in res.items():
        # setdefault: a layer absent from earlier logs is added rather than
        # raising KeyError on total[layer_name].
        total.setdefault(layer_name, {}).update(errors)

# For each layer pick the scale with the smallest proxy error, and count how
# many layers each scale wins.
hist = defaultdict(int)
best_layer = {}
for layer, errors in total.items():
    best = float('inf')
    best_scale = None  # stays None if no error compares below inf (e.g. all NaN)
    for scale, err in errors.items():
        # Strict '<' keeps the earliest scale (in `files` order) on ties.
        if err < best:
            best = err
            best_scale = scale
    best_layer[layer] = best_scale
    hist[best_scale] += 1
print(hist)
# Stop after reporting: the checkpoint-assembly code that follows only runs if
# this line is removed. sys.exit() is used because the builtin exit() is an
# interactive-shell helper injected by the site module.
sys.exit()
# Assemble a mixed checkpoint directory: for every layer, copy its .pt file
# from the run whose scale had the lowest proxy error (best_layer).
ckpt_path = '/work/albert/two_bit_quant/checkpoints'
out_path = os.path.join(ckpt_path, 'e8p_best_scale')

# Start from an empty output directory (equivalent of `rm -rf` + `mkdir`,
# but failures raise instead of being silently ignored).
if os.path.islink(out_path) or os.path.isfile(out_path):
    os.remove(out_path)
elif os.path.isdir(out_path):
    shutil.rmtree(out_path)
os.makedirs(out_path)

# config.pt is taken from the first run.
# NOTE(review): assumes config.pt is identical across scales — confirm.
shutil.copy(os.path.join(ckpt_path, f'e8p_s{files[0]}', 'config.pt'), out_path)

for layer, scale in best_layer.items():
    # Layer names carry spaces (joined with ' ' when parsed from the log);
    # checkpoint filenames use underscores instead.
    src = os.path.join(ckpt_path, f'e8p_s{scale}', '{}.pt'.format(layer.replace(' ', '_')))
    # Raises (e.g. FileNotFoundError) if a layer file is missing, so an
    # incomplete checkpoint is not produced silently.
    shutil.copy(src, out_path)