Spaces:
Sleeping
Sleeping
| import os | |
| from tqdm import tqdm | |
| import argparse | |
| import pickle | |
| from src.plotting.eval_matrix import matrix_plot, scatter_plot | |
| from src.utils.paths import get_path | |
| import matplotlib.pyplot as plt | |
# Command-line interface: choose which batch-evaluation directory to summarise.
parser = argparse.ArgumentParser(
    description="Plot precision, recall and F1 of dark-quark matching versus "
    "model size and training step."
)
parser.add_argument(
    "--input",
    type=str,
    default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/small_dataset",
    help="Evaluation directory (resolved with get_path(..., 'results')) holding "
    "one subdirectory per evaluated model.",
)
args = parser.parse_args()
# Resolved results directory whose entries are the evaluated models.
path = get_path(args.input, "results")
def get_steps(config):
    """Return the training step of the checkpoint described by ``config``.

    An explicit ``config["ckpt_step"]`` entry wins. Otherwise the step is parsed
    from the file name in ``config["load_model_weights"]``, which is expected to
    look like ``/.../step_<step>_epoch_<epoch>.ckpt`` (the second
    underscore-separated field is taken as the step).
    """
    if "ckpt_step" not in config:
        # Fallback: parse it out of the checkpoint file name.
        checkpoint_name = config["load_model_weights"].rpartition("/")[2]
        return int(checkpoint_name.split("_")[1])
    return config["ckpt_step"]
def get_short(network_config):
    """Return a short network-family label for a network config string.

    Matching is a case-insensitive substring test. The first hit in the order
    Transformer, LGATr, GATr wins; anything else is "Unknown".
    """
    lowered = network_config.lower()
    # Order matters: "gatr" is a substring of "lgatr", so LGATr must be tested first.
    families = (
        ("transformer", "Transformer"),
        ("lgatr", "LGATr"),
        ("gatr", "GATr"),
    )
    return next((label for needle, label in families if needle in lowered), "Unknown")
def get_model_details(path_to_eval):
    """Summarise the model evaluated in ``path_to_eval``.

    Reads ``run_config.pkl`` from that directory and returns the tuple
    ``(config["num_parameters"], get_short(config["network_config"]), get_steps(config))``.
    """
    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at run configs produced by trusted training/evaluation jobs.
    with open(os.path.join(path_to_eval, "run_config.pkl"), "rb") as f:
        config = pickle.load(f)
    return config["num_parameters"], get_short(config["network_config"]), get_steps(config)
# Reference (AK8) entries appended after the trained models; in `data` they get
# placeholder zeros for n_params and training step.
BASELINES = ("AK8", "AK8_GenJets")
# Every non-file entry of the results directory is an evaluated model. Entries
# containing "AK8" are skipped here and re-added explicitly, in fixed order, below.
models = sorted(
    entry
    for entry in os.listdir(path)
    if not (os.path.isfile(os.path.join(path, entry)) or "AK8" in entry)
)
# data[i] == (n_params, net_short, step) describing models[i].
data = [get_model_details(os.path.join(path, model)) for model in models]
data += [(0, baseline, 0) for baseline in BASELINES]
models += list(BASELINES)
# `path` already holds get_path(args.input, "results"); reuse it.
out_file_PR = os.path.join(path, "precision_recall_n_params.pdf")
sz = 5  # edge length (inches) of one subplot panel
fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3))
# net_short -> [n_params, precision, recall, F1] lists at the "700_07" working point.
result_scatter = {}
# Same layout as result_scatter, at the "900_03" working point.
result_scatter_900_03 = {}
# working point -> architecture label -> [step, precision, recall, F1] lists.
result_by_step = {"900_03": {}, "700_07": {}}
def get_arch_name(n_params, net_short):
    """Map a (parameter count, network family) pair to an architecture label.

    Known configurations are recognised by their exact parameter count within
    the Transformer / LGATr / GATr families. Any ``net_short`` containing "AK8"
    is returned unchanged. Everything else yields ``None``.
    """
    # family -> {exact parameter count: label}
    # NOTE(review): 1201108 is listed for both Transformer and LGATr; confirm
    # this is intentional and not a copy-paste slip.
    known_configs = {
        "Transformer": {
            4674: "Tr-2-16-4",
            1201108: "Tr",
            1322274: "Tr",
            167394: "Tr-5-64-4",
        },
        "LGATr": {
            8424: "LGATr-2-4-4",
            1201108: "LGATr",
            156332: "LGATr-3-16-16",
        },
        "GATr": {
            6533: "GATr-2-4-4",
            926041: "GATr",
        },
    }
    label = known_configs.get(net_short, {}).get(n_params)
    if label is not None:
        return label
    return net_short if "AK8" in net_short else None
# Training step whose checkpoints feed the "metric vs. number of parameters" plots.
SCATTER_STEP = 40000
# Working-point key -> index path into result_PR; result_PR[a][b][c] holds
# (precision, recall, ...).
# NOTE(review): the indices presumably encode signal-sample parameters (the
# "900_03" / "700_07" keys mirror 900 & 0.3 and 700 & 0.7) -- TODO confirm.
WORKING_POINTS = {"900_03": (900, 20, 0.3), "700_07": (700, 20, 0.7)}


def _f1(precision, recall):
    """Return the F1 score (harmonic mean of precision and recall).

    Returns 0.0 when precision + recall == 0 instead of raising
    ZeroDivisionError, which the inline formula did for checkpoints that
    match nothing.
    """
    total = precision + recall
    return 0.0 if total == 0 else 2 * precision * recall / total


def _load_pickle(file_path):
    """Unpickle and return the object stored in ``file_path``, closing the file."""
    # NOTE(review): pickle can execute arbitrary code; these files are assumed
    # to be trusted outputs of this project's evaluation jobs.
    with open(file_path, "rb") as f:
        return pickle.load(f)


def _metrics_at(result_PR, wp_key):
    """Return (precision, recall, F1) stored in result_PR at working point ``wp_key``."""
    a, b, c = WORKING_POINTS[wp_key]
    pr = result_PR[a][b][c]
    return pr[0], pr[1], _f1(pr[0], pr[1])


def _append_row(table, key, x_value, metrics):
    """Append ``x_value`` and the three metrics to the four parallel lists of ``table[key]``."""
    columns = table.setdefault(key, [[], [], [], []])
    for column, value in zip(columns, (x_value, *metrics)):
        column.append(value)


# For every model with evaluation output, collect (x, precision, recall, F1):
# per architecture and training step into result_by_step, and, for checkpoints
# at SCATTER_STEP, per network family into result_scatter / result_scatter_900_03.
for i, model in tqdm(enumerate(models), total=len(models)):
    output_path = os.path.join(path, model, "count_matched_quarks")
    result_PR_file = os.path.join(output_path, "result_PR.pkl")
    # Only result_PR.pkl is used below, so gate on (and load) just that file.
    if not os.path.exists(result_PR_file):
        continue
    result_PR = _load_pickle(result_PR_file)
    n_params, net_short, step = data[i]
    # Per-model matrix plots (currently disabled):
    #matrix_plot(result_PR, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0], ax=ax[0, i])
    #matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1], ax=ax[1, i])
    #matrix_plot(result_PR, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax[2, i])
    arch = get_arch_name(n_params, net_short)
    if arch is not None:
        for wp_key in WORKING_POINTS:
            _append_row(result_by_step[wp_key], arch, step, _metrics_at(result_PR, wp_key))
    if step != SCATTER_STEP:
        continue
    title = f"{n_params} {net_short}"
    for row in range(3):
        ax[row, i].set_title(title)
    _append_row(result_scatter, net_short, n_params, _metrics_at(result_PR, "700_07"))
    _append_row(result_scatter_900_03, net_short, n_params, _metrics_at(result_PR, "900_03"))
# Row labels shared by every precision / recall / F1 figure below.
_METRIC_LABELS = ("Precision", "Recall", "F1 score")
# Colours of the dashed reference lines drawn for the AK8 entries.
_BASELINE_COLORS = {"AK8": "gray", "AK8_GenJets": "black"}


def _derived_path(tag):
    """Return out_file_PR with ``_<tag>`` inserted before its extension."""
    # splitext only touches the final extension, unlike str.replace(".pdf", ...),
    # which would also rewrite ".pdf" appearing in a directory name.
    root, ext = os.path.splitext(out_file_PR)
    return root + "_" + tag + ext


def _save_figure(figure, file_path):
    """Lay out, save and close ``figure``, then report where it was written."""
    figure.tight_layout()
    figure.savefig(file_path)
    print("Saved to", file_path)
    plt.close(figure)


def _decorate_column(axes, xlabel, title=None):
    """Add labels, legend, grid and a log x-axis to a column of three metric axes."""
    for axis, ylabel in zip(axes, _METRIC_LABELS):
        if title is not None:
            axis.set_title(title)
        axis.set_ylabel(ylabel)
        axis.set_xlabel(xlabel)
        axis.legend()
        axis.grid()
        axis.set_xscale("log")


def _plot_vs_n_params(results, file_path):
    """Scatter precision, recall and F1 against parameter count, one series per network family."""
    figure, axes = plt.subplots(3, 1, figsize=(sz, sz * 3))
    for net_short, (n_params, *metrics) in results.items():
        for axis, values in zip(axes, metrics):
            scatter_plot(axis, n_params, values, net_short)
    _decorate_column(axes, "N params")
    _save_figure(figure, file_path)


def _plot_vs_step(results_by_wp, file_path):
    """Plot each architecture's metrics against training step, one column per working point."""
    wp_keys = sorted(results_by_wp)
    n_cols = len(wp_keys)
    # squeeze=False keeps `axes` 2-D even when there is a single working point.
    figure, axes = plt.subplots(3, n_cols, figsize=(sz * n_cols, sz * 3), squeeze=False)
    for col, wp_key in enumerate(wp_keys):
        for model, (steps, *metrics) in results_by_wp[wp_key].items():
            is_baseline = "AK8" in model
            for axis, values in zip(axes[:, col], metrics):
                if is_baseline:
                    # Only one value per AK8 entry (recorded at step 0, which a
                    # log x-axis cannot show), so draw a dashed horizontal line instead.
                    axis.axhline(
                        values[0],
                        label=model,
                        color=_BASELINE_COLORS.get(model, "gray"),
                        linestyle="--",
                    )
                else:
                    scatter_plot(axis, steps, values, model)
        _decorate_column(axes[:, col], "training steps", title=wp_key)
    _save_figure(figure, file_path)


_save_figure(fig, out_file_PR)
_plot_vs_n_params(result_scatter, _derived_path("scatter_700_07"))
_plot_vs_n_params(result_scatter_900_03, _derived_path("scatter_900_03"))
_plot_vs_step(result_by_step, _derived_path("by_step"))