"""Print per-run validation accuracy collected by ``valids.main``.

Results are grouped by run directory and printed one run at a time: the run
path, then a tab-separated row with the cola, qnli, mrpc, rte and sst_2
accuracies followed by the mean over all tasks and the mean over all tasks
except rte. If any of those five tasks is missing for a run, the raw
``{task: accuracy}`` mapping is printed instead.
"""
import os.path as osp

from valids import parser, main as valids_main


# Options consumed by valids_main; their exact semantics are defined in the
# `valids` module. They are overridden here regardless of the command line.
args = parser.parse_args()
args.target = "valid_accuracy"
args.best_biggest = True
args.best = True
args.last = 0
args.path_contains = None

res = valids_main(args, print_output=False)

# Build {run_dir: {task_name: accuracy}}. Each result key is a path; its
# parent directory names the task and that directory's parent is the run.
grouped = {}
for k, v in res.items():
    k = osp.dirname(k)
    run = osp.dirname(k)
    task = osp.basename(k)
    val = v["valid_accuracy"]
    grouped.setdefault(run, {})[task] = val

for run, tasks in grouped.items():
    print(run)

    # `tasks` is never empty here: a run only exists once a task was added.
    avg = sum(float(v) for v in tasks.values()) / len(tasks)

    # Mean over every task except rte. Divide by the number of values that
    # were actually summed so a run holding a single task (or no rte entry)
    # cannot trigger a ZeroDivisionError or a skewed denominator.
    non_rte = [float(v) for k, v in tasks.items() if k != "rte"]
    avg_norte = sum(non_rte) / len(non_rte) if non_rte else float("nan")

    try:
        print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
    except KeyError:
        # At least one of the five expected tasks is missing for this run;
        # fall back to dumping whatever was collected.
        print(tasks)
    print()