import json import os from statistics import mean import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap # EvalPerf result file name format: {MODEL}_temp_{TEMP}_ep-{TYPE}_results.json # Draw a heatmap of pairwise comparison of models # each pair of models is compared using the same set of passing tasks def main(result_dir: str): assert os.path.isdir(result_dir), f"{result_dir} is not a directory." model2task2dps = {} model2task2dps_norm = {} model_list = [] model_e2e_dps = [] for result_json in os.listdir(result_dir): if not result_json.endswith(".json"): continue result_json_path = os.path.join(result_dir, result_json) assert "_temp_0.2_" in result_json, f"Invalid result file name: {result_json}" model_id = result_json.split("_temp_0.2_")[0] if model_id.endswith("-instruct") and not model_id.endswith(" perf-instruct"): model_id = model_id[: -len("-instruct")] model_id += " :: default" if "::" not in model_id: model_id += " :: default" print(f"Processing {model_id}") with open(result_json_path) as f: results = json.load(f) task2dps = {} task2dps_norm = {} for task_id, result in results.items(): if "scores" in result and result["scores"] is not None: task2dps[task_id] = result["scores"]["max"] task2dps_norm[task_id] = result["norm_scores"]["max"] if "dps" in result and result["dps"] is not None: task2dps[task_id] = max(result["dps"]) task2dps_norm[task_id] = max(result["dps_norm"]) model2task2dps[model_id] = task2dps model2task2dps_norm[model_id] = task2dps_norm model_list.append(model_id) model_e2e_dps.append(mean(task2dps.values())) # sort model list by dps score model_list, model_e2e_dps = zip( *sorted(zip(model_list, model_e2e_dps), key=lambda x: x[1], reverse=True) ) # model_list = model_list[:32] fig, ax = plt.subplots(figsize=(30, 25)) score_matrix = [] for i, model_x in enumerate(model_list): score_list = [] task2dps_x = model2task2dps[model_x] for j, model_y in enumerate(model_list): if j <= i: score_list.append((0, 0)) continue task2dps_y = model2task2dps[model_y] common_tasks = set(task2dps_x.keys()) & set(task2dps_y.keys()) if len(common_tasks) == 0: score_list.append(None) print( f"[Warning] no common passing set between {model_x} and {model_y}" ) continue dps_x = mean([task2dps_x[task_id] for task_id in common_tasks]) dps_y = mean([task2dps_y[task_id] for task_id in common_tasks]) score_list.append((dps_x, dps_y)) text = f"{round(dps_x)}" if dps_x - dps_y >= 1: text += f"\n+{dps_x - dps_y:.1f}" elif dps_x - dps_y <= -1: text += f"\n-{dps_y - dps_x:.1f}" ax.text( j, i, text, va="center", ha="center", color="green" if dps_x > dps_y else "red", ) score_matrix.append(score_list) # print(score_matrix) score_matrix_diff = [ [None if score is None else score[0] - score[1] for score in score_list] for score_list in score_matrix ] cmap = LinearSegmentedColormap.from_list("rg", ["r", "w", "lime"], N=256) cax = ax.matshow(score_matrix_diff, cmap=cmap) cax.set_clim(-15, 15) fig.colorbar(cax) ax.set_xticks(range(len(model_list))) ax.set_yticks(range(len(model_list))) ax.set_xticklabels(model_list, rotation=45, ha="left", rotation_mode="anchor") ax.set_yticklabels(model_list) # save fig plt.savefig("pairwise_heatmap.png", dpi=120, bbox_inches="tight") plt.savefig("pairwise_heatmap.pdf", bbox_inches="tight") if __name__ == "__main__": from fire import Fire Fire(main)