import os
import shutil
from collections import defaultdict

import torch


def parse_proxy(fname, scale):
    """Collect per-layer proxy errors from one quantization log file.

    Every line containing ``'proxy error'`` is parsed as follows: the text
    from the first ``'layer'`` onward is kept, the two space-separated
    tokens following ``'layer'`` form the layer name, and everything after
    the first ``':'`` in that remainder is parsed as the error value.
    Lines mentioning ``'proxy error'`` but no ``'layer'`` token are skipped.

    Args:
        fname: Path of the log file to read.
        scale: Label stored as the inner key for every parsed error (the
            callers in this script pass the scale suffix string, e.g. '090').

    Returns:
        dict mapping layer name -> {scale: proxy_error}. If a layer is
        reported more than once, the last occurrence wins.

    Raises:
        ValueError: if the text after ':' cannot be parsed as a float.
    """
    layer_dict = {}
    # Context manager so the handle is released even if float() raises on a
    # malformed line (the previous version never closed the file at all).
    with open(fname, 'r') as f:
        for line in f:
            if 'proxy error' not in line:
                continue
            line = line.rstrip()
            start = line.find('layer')
            if start == -1:
                # Slicing at -1 would keep only the last character and record
                # a bogus entry under the empty layer name ''.
                continue
            line = line[start:]
            proxy_error = float(line[line.find(':') + 1:])
            layer_name = ' '.join(line.split(' ')[1:3])
            layer_dict[layer_name] = {scale: proxy_error}
    return layer_dict


# Merge the per-scale logs into one table: total[layer_name][scale] = proxy error.
# Inner dicts are filled in the order of `files`.
files = ['075', '080', '085', '090', '095', '100', '103', '105']
total = {}
for key in files:
    res = parse_proxy(f'/work/albert/two_bit_quant/slurm_out/e8p_s{key}.log', key)
    # Distinct loop name so the outer `key` is not shadowed; setdefault means a
    # layer missing from an earlier log is added instead of raising KeyError.
    for layer_name, errors in res.items():
        total.setdefault(layer_name, {}).update(errors)

# For every layer, keep the scale label with the smallest proxy error and
# count how many layers each scale wins. The comparison is strictly-less, so
# on ties the scale inserted first into the layer's dict is kept, NaN errors
# never win, and a layer with no finite error ends up mapped to None.
hist = defaultdict(int)  # scale label -> number of layers it is best for
best_layer = {}          # layer name -> best scale label
for layer, errors_by_scale in total.items():
    best_scale, best = None, float('inf')
    for scale, err in errors_by_scale.items():
        if err < best:
            best_scale, best = scale, err
    best_layer[layer] = best_scale
    hist[best_scale] += 1


print(hist)
# NOTE(review): this early stop makes the checkpoint-assembly code below
# unreachable; it looks like a deliberate pause to inspect the histogram —
# confirm before removing. `raise SystemExit` behaves like the site-module
# `exit()` helper (exit status 0) but, unlike it, does not depend on the
# `site` module being loaded, which makes it safe to use in scripts.
raise SystemExit

# Build a checkpoint directory whose per-layer files are taken from the scale
# run selected in `best_layer` for that layer.
ckpt_path = '/work/albert/two_bit_quant/checkpoints'
out_path = os.path.join(ckpt_path, 'e8p_best_scale')

# Recreate an empty output directory. Done without a shell so paths are not
# word-split, and any failure raises instead of being silently ignored.
if os.path.exists(out_path):
    shutil.rmtree(out_path)
os.makedirs(out_path)

# config.pt is copied from the first scale run listed in `files`
# (presumably identical across runs — TODO confirm).
shutil.copy(os.path.join(ckpt_path, f'e8p_s{files[0]}', 'config.pt'), out_path)

for layer, scale in best_layer.items():
    # Spaces in layer names become underscores in the per-layer file names.
    src = os.path.join(ckpt_path, f'e8p_s{scale}', '{}.pt'.format(layer.replace(' ', '_')))
    # shutil.copy raises FileNotFoundError for a missing layer file, so an
    # incomplete checkpoint can no longer be produced silently.
    shutil.copy(src, out_path)