Spaces:
Sleeping
Sleeping
| import dgl | |
| import torch | |
| import os | |
| from alembic.command import current | |
| from sklearn.cluster import DBSCAN, HDBSCAN | |
| from torch_scatter import scatter_max, scatter_add, scatter_mean | |
| import numpy as np | |
| from src.dataset.functions_data import CachedIndexList, spherical_to_cartesian | |
| import matplotlib.pyplot as plt | |
| from scipy.optimize import linear_sum_assignment | |
| import pandas as pd | |
| import wandb | |
| from src.utils.inference.per_particle_metrics import plot_event | |
| import random | |
| import string | |
def generate_random_string(length):
    """Return a random alphanumeric string made of ``length`` characters.

    Each character is drawn independently (with replacement) from ASCII
    letters and digits using the module-level ``random`` generator, so the
    output is reproducible after ``random.seed``. Not meant for secrets.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)
def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    tracking=False,
    e_corr=None,
    shap_vals=None,
    ec_x=None,  # ec_x: "global" features (what gets inputted into the final deep neural network head) for energy correction
    tracks=False,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pids_neutral=None,
    pids_charged=None,
    pred_pid=None,
    pred_xyz_track=None,
    number_of_fakes=None,
):
    """Cluster every event of a batch, match clusters to truth and tabulate it.

    The model output is written into ``batch_g.ndata``:
    columns 0:3 as ``"coords"``, column 3 as ``"beta"`` and, when neither
    ``tracking`` is set nor ``e_corr`` is given, column 4 as ``"correction"``.
    The batch is then unbatched. For each event, hits are labelled (ground-truth
    ``particle_number`` if ``use_gt_clusters``, else ``hfdb_obtain_labels``),
    matched to true particles with ``match_showers`` and converted into a
    per-particle DataFrame by ``generate_showers_data_frame``. In ``predict``
    mode the same is done for the Pandora clustering.

    Args:
        batch_g: batched DGL graph of the hits (modified in place, see above).
        model_output: per-hit model output tensor.
        y: truth container; must provide ``batch_number``, ``copy()`` and
            ``mask()``.
        local_rank, step, epoch: identifiers used in output file names.
        path_save: output directory (string).
        store: when true, pass the frames to ``store_at_batch_end``.
        predict: also evaluate Pandora and return its frame.
        e_corr, pred_pos, pred_ref_pt, pred_pid, number_of_fakes, shap_vals,
            ec_x: batch-level model predictions forwarded to
            ``generate_showers_data_frame``.
        total_number_events: running event counter used as the event id.
        pids_neutral, pids_charged, pred_xyz_track: unused here; kept for
            call compatibility.

    Returns:
        ``(df_batch_pandora, df_batch1, total_number_events)`` when ``predict``
        is true, otherwise ``df_batch1``. A frame is an empty ``pd.DataFrame()``
        when no event of the batch produced rows (previously this raised
        ``ValueError`` from ``pd.concat``).
    """
    # Running offsets into the batch-level prediction tensors (e_corr, pred_pos, ...).
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if not tracking and e_corr is None:
        batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)
    df_list1 = []
    df_list_pandora = []
    for i, graph in enumerate(graphs):
        dic = {"graph": graph}
        y1 = y.copy()
        y1.mask(batch_id == i)
        dic["part_true"] = y1
        X = graph.ndata["coords"]
        if use_gt_clusters:
            labels_hdb = graph.ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = hfdb_obtain_labels(X, model_output.device)
        if predict:
            labels_pandora = get_labels_pandora(tracks, dic, model_output.device)
        particle_ids = torch.unique(graph.ndata["particle_number"])
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, _ = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            tracks=tracks,
            hdbscan=True,
        )
        if predict:
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                _,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
                tracks=tracks,
            )
        if len(shower_p_unique_hdb) <= 1:
            # Nothing beyond the noise class: the event contributes no rows.
            continue
        df_event1, showers_total, fakes_total = generate_showers_data_frame(
            labels_hdb,
            dic,
            shower_p_unique_hdb,
            particle_ids,
            row_ind_hdb,
            col_ind_hdb,
            i_m_w_hdb,
            e_corr=e_corr,
            number_of_showers_total=number_of_showers_total1,
            step=step,
            number_in_batch=total_number_events,
            tracks=tracks,
            ec_x=ec_x,
            shap_vals=shap_vals,
            pred_pos=pred_pos,
            pred_ref_pt=pred_ref_pt,
            pred_pid=pred_pid,
            save_plots_to_folder=path_save + "/ML_Model_evt_plots_debugging",
            number_of_fakes=number_of_fakes,
            number_of_fake_showers_total=number_of_fake_showers_total1,
        )
        if isinstance(df_event1, pd.DataFrame):
            # Only advance the offsets when a frame was produced: the "no match"
            # sentinel (an empty list) may carry zeros that would rewind them.
            number_of_showers_total1 = showers_total
            number_of_fake_showers_total1 = fakes_total
            # NOTE(review): single-row events are dropped; confirm `> 0` was not intended.
            if len(df_event1) > 1:
                df_list1.append(df_event1)
        if predict:
            df_event_pandora = generate_showers_data_frame(
                labels_pandora,
                dic,
                shower_p_unique_pandora,
                particle_ids,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                pandora=True,
                tracking=tracking,
                step=step,
                number_in_batch=total_number_events,
                tracks=tracks,
                save_plots_to_folder=path_save + "/Pandora_evt_plots_debugging",
            )
            # The "no match" sentinel comes back as a tuple; keep real frames only.
            if isinstance(df_event_pandora, pd.DataFrame):
                df_list_pandora.append(df_event_pandora)
        total_number_events = total_number_events + 1

    df_batch1 = pd.concat(df_list1) if df_list1 else pd.DataFrame()
    if predict:
        df_batch_pandora = (
            pd.concat(df_list_pandora) if df_list_pandora else pd.DataFrame()
        )
    else:
        df_batch_pandora = []
    # log_efficiency (called by store_at_batch_end) reads result columns that an
    # empty frame does not have, so skip storing when there is nothing to store.
    have_rows = len(df_batch1) > 0 and (not predict or len(df_batch_pandora) > 0)
    if store and have_rows:
        store_at_batch_end(
            path_save,
            df_batch1,
            df_batch_pandora,
            local_rank,
            step,
            epoch,
            predict=predict,
            store=store_epoch,
        )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    return df_batch1
def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
):
    """Optionally pickle the batch result frames, then log matching efficiency.

    Only when both ``predict`` and ``store`` are true:
    - ``df_batch1`` is pickled to ``<path_save>/<local_rank>_<step>_<epoch>_hdbscan.pt``;
    - ``df_batch_pandora`` is pickled to ``..._pandora.pt``.

    Efficiency of ``df_batch1`` is always logged. In predict mode the Pandora
    frame's efficiency is logged as well.
    """
    file_stem = path_save + "/" + "_".join(str(part) for part in (local_rank, step, epoch))
    if store and predict:
        df_batch1.to_pickle(file_stem + "_hdbscan.pt")
        df_batch_pandora.to_pickle(file_stem + "_pandora.pt")
    log_efficiency(df_batch1)
    if predict:
        log_efficiency(df_batch_pandora, pandora=True)
def log_efficiency(df, pandora=False, clustering=False):
    """Log the share of reconstructable particles that received a prediction.

    The denominator is the set of rows whose ``reco_showers_E`` is not NaN.
    The numerator counts the rows among those whose ``pred_showers_E`` is
    also not NaN. The ratio is sent to Weights & Biases under a key chosen by
    ``pandora`` (which takes precedence) or ``clustering``.
    """
    has_reco = ~np.isnan(df["reco_showers_E"])
    pred_for_reco = df["pred_showers_E"][has_reco].values
    eff = np.sum(~np.isnan(pred_for_reco)) / len(pred_for_reco)
    if pandora:
        metric_name = "efficiency validation pandora"
    elif clustering:
        metric_name = "efficiency validation clustering"
    else:
        metric_name = "efficiency validation"
    wandb.log({metric_name: eff})
def _tail_rows(tensor, count):
    """Return the last ``count`` rows of ``tensor``, or an empty slice for 0.

    ``tensor[-count:]`` would return the *whole* tensor when ``count == 0``
    because ``-0 == 0``.
    """
    return tensor[tensor.shape[0] - count:]


def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    tracking=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    tracks=False,
    shap_vals=None,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    save_plots_to_folder="",
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None,
):
    """Build a per-particle DataFrame for one event from a cluster-to-truth matching.

    Rows are laid out as follows:
    - The first rows are the true particles of ``dic["part_true"]``. Each one
      carries the energy (and, where available, position / PID) of the
      predicted cluster matched to it via ``row_ind`` / ``col_ind``, or NaN
      when unmatched.
    - Those rows are followed by one row per unmatched ("fake") predicted
      cluster. Its truth columns are NaN.

    Args:
        labels: per-hit cluster label.
        dic: ``{"graph": event graph, "part_true": truth container}``.
        shower_p_unique: cluster ids from ``match_showers``; not modified.
        particle_ids: unique ``particle_number`` values of the event.
        row_ind, col_ind: matching indices from ``match_showers``.
        i_m_w: energy-weighted intersection matrix.
        pandora: read energies/momenta from the Pandora hit features.
        e_corr, pred_pos, pred_ref_pt, pred_pid: batch-level model predictions.
            The code reads this event's matched entries at the running offset
            ``number_of_showers_total`` and its fake entries from the last
            ``number_of_fakes`` rows at offset ``number_of_fake_showers_total``.
        number_of_showers_total, number_of_fake_showers_total: running offsets.
        save_plots_to_folder: if non-empty, an event display is written per event.

    Returns:
        ``df`` when ``number_of_showers_total`` is None, otherwise
        ``(df, number_of_showers_total, number_of_fake_showers_total)`` with
        the offsets advanced past this event. When nothing was matched
        (``col_ind`` empty) returns ``([], number_of_showers_total,
        number_of_fake_showers_total)`` with the offsets unchanged.
    """
    shap = shap_vals is not None
    ndata = dic["graph"].ndata
    e_pred_showers = scatter_add(ndata["e_hits"].view(-1), labels)
    device = e_pred_showers.device
    n_fakes_in_batch = number_of_fakes or 0  # None means "no fake predictions"

    if pandora:
        e_pred_showers_cali = scatter_mean(ndata["pandora_cluster_energy"].view(-1), labels)
        e_pred_showers_pfo = scatter_mean(ndata["pandora_pfo_energy"].view(-1), labels)
        pandora_pid = scatter_mean(ndata["pandora_pid"], labels)
        if "pandora_momentum" in ndata:
            momentum = ndata["pandora_momentum"]
            ref_point = ndata["pandora_reference_point"]
            pxyz_pred_pfo = torch.stack(
                [scatter_mean(momentum[:, k], labels) for k in range(3)], dim=1
            )
            ref_pt_pred_pfo = torch.stack(
                [scatter_mean(ref_point[:, k], labels) for k in range(3)], dim=1
            )
        else:
            # No Pandora momentum stored: report NaN instead of failing later.
            pxyz_pred_pfo = torch.full((e_pred_showers.shape[0], 3), torch.nan, device=device)
            ref_pt_pred_pfo = pxyz_pred_pfo.clone()
    elif e_corr is None:
        corrections_per_shower = get_correction_per_shower(labels, dic)
        e_pred_showers_cali = e_pred_showers * corrections_per_shower
    else:
        corrections_per_shower = e_corr.view(-1)
        # Drop the trailing fake-shower entries so offsets index matched candidates.
        corrections_per_shower = corrections_per_shower[
            : corrections_per_shower.shape[0] - n_fakes_in_batch
        ]

    e_reco_showers = scatter_add(ndata["e_hits"].view(-1), ndata["particle_number"].long())
    row_ind = torch.Tensor(row_ind).to(device).long()
    col_ind = torch.Tensor(col_ind).to(device).long()
    # When a noise particle (id 0) exists, matched row indices are shifted by one.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind
    # Work on a copy: matched and noise entries are overwritten with -1 below.
    pred_showers = shower_p_unique.clone()

    part_true = dic["part_true"]
    energy_t = part_true.E_corrected.view(-1).to(device)
    vertex = part_true.vertex.to(device)
    pos_t = part_true.coord.to(device)
    pid_t = part_true.pid.to(device)
    n_true = energy_t.shape[0]
    is_track_per_shower = scatter_add((ndata["hit_type"] == 1), labels).int()
    is_track = torch.zeros(energy_t.shape).to(device)

    if shap:
        matched_shap_vals = (torch.zeros((n_true, shap_vals.shape[1])) * torch.nan).numpy()
        matched_ec_x = (torch.zeros((n_true, ec_x.shape[1])) * torch.nan).numpy()
    index_matches = (col_ind + 1).to(device).long()
    matched_es = (torch.zeros_like(energy_t) * torch.nan).to(device)
    matched_positions = (torch.zeros((n_true, 3)) * torch.nan).to(device)
    matched_ref_pt = (torch.zeros((n_true, 3)) * torch.nan).to(device)
    # NOTE(review): NaN cast to long yields an implementation-defined integer.
    matched_pid = (torch.zeros_like(energy_t) * torch.nan).to(device).long()
    matched_positions_pfo = (torch.zeros((n_true, 3)) * torch.nan).to(device)
    matched_pandora_pid = (torch.zeros((n_true,)) * torch.nan).to(device)
    matched_ref_pts_pfo = (torch.zeros((n_true, 3)) * torch.nan).to(device)

    matched_es[row_ind_] = e_pred_showers[index_matches]
    is_track[row_ind_] = is_track_per_shower[index_matches].float()
    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
        matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
    elif e_corr is None:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        calibration_per_shower = matched_es.clone()
        calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
    else:
        # This event's matched predictions start at the running batch offset.
        number_of_showers = index_matches.shape[0]
        first = number_of_showers_total
        last = number_of_showers_total + number_of_showers
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = corrections_per_shower[first:last]
        if pred_pos is not None:
            matched_positions[row_ind_] = pred_pos[first:last]
            matched_ref_pt[row_ind_] = pred_ref_pt[first:last]
            matched_pid[row_ind_] = pred_pid[first:last]
        calibration_per_shower = matched_es.clone()
        calibration_per_shower[row_ind_] = corrections_per_shower[first:last]
        number_of_showers_total = last
    if shap:
        # NOTE(review): indexed by per-event cluster index, unlike e_corr which
        # uses the running batch offset - confirm the layout of shap_vals/ec_x.
        matched_shap_vals[row_ind_.cpu()] = shap_vals[index_matches.cpu()]
        matched_ec_x[row_ind_.cpu()] = ec_x[index_matches.cpu()]

    intersection_E = torch.zeros_like(energy_t) * torch.nan
    if len(col_ind) == 0:
        # Nothing matched: return the running offsets unchanged rather than
        # zeros, so the caller does not rewind them.
        # NOTE(review): this event's fake showers are not added to the fake offset.
        return [], number_of_showers_total, number_of_fake_showers_total
    ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
    intersection_E[row_ind_] = ie_e.to(device)

    pred_showers[index_matches] = -1
    pred_showers[0] = -1  # class 0 is noise for both Pandora and the clustering
    mask = pred_showers != -1
    number_of_fake_showers = int(mask.sum())
    fake_showers_e = e_pred_showers[mask]

    use_model_fakes = not pandora and e_corr is not None
    if use_model_fakes:
        fake_first = number_of_fake_showers_total
        fake_last = fake_first + number_of_fake_showers
        fake_showers_e_cali = _tail_rows(e_corr.view(-1), n_fakes_in_batch)[
            fake_first:fake_last
        ].to(device)
    else:
        fake_showers_e_cali = e_pred_showers_cali[mask]

    if pandora:
        fake_showers_e_pfo = e_pred_showers_pfo[mask]
        fake_pandora_pid = pandora_pid[mask]
        fake_positions_pfo = pxyz_pred_pfo[mask]
        fakes_positions_ref = ref_pt_pred_pfo[mask]
    else:
        if e_corr is None:
            fake_showers_e_cali_factor = corrections_per_shower[mask]
        else:
            fake_showers_e_cali_factor = fake_showers_e_cali
        if number_of_fake_showers_total is not None:
            number_of_fake_showers_total = number_of_fake_showers_total + number_of_fake_showers

    # Fake showers have no truth: pad the truth columns with NaN.
    nan_1d = (torch.zeros((number_of_fake_showers,)) * torch.nan).to(device)
    nan_3d = (torch.zeros((number_of_fake_showers, 3)) * torch.nan).to(device)
    energy_t = torch.cat((energy_t, nan_1d), dim=0)
    vertex = torch.cat((vertex, nan_3d), dim=0)
    pid_t = torch.cat((pid_t.view(-1), nan_1d), dim=0)
    pos_t = torch.cat((pos_t, nan_3d), dim=0)
    e_reco = torch.cat((e_reco_showers[1:], nan_1d), dim=0)
    e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
    e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0)
    e_pred_t = torch.cat((intersection_E, torch.zeros_like(fake_showers_e) * torch.nan), dim=0)
    is_track = torch.cat((is_track, is_track_per_shower[mask].to(is_track.device)), dim=0)

    if pred_pos is not None:
        if use_model_fakes:
            fakes_positions = _tail_rows(pred_pos, n_fakes_in_batch)[fake_first:fake_last].to(device)
            fakes_ref_pt = _tail_rows(pred_ref_pt, n_fakes_in_batch)[fake_first:fake_last].to(device)
            fakes_pid_pred = _tail_rows(pred_pid, n_fakes_in_batch)[fake_first:fake_last].to(device)
        else:
            # No per-fake model predictions on this path.
            fakes_positions = nan_3d.clone()
            fakes_ref_pt = nan_3d.clone()
            fakes_pid_pred = nan_1d.clone()
        e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0)
        e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0)
        e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_ref_pt), dim=0)

    if shap:
        # NOTE(review): fakes are padded with zeros while unmatched particles are NaN.
        matched_shap_vals = torch.cat(
            (torch.tensor(matched_shap_vals), torch.zeros((number_of_fake_showers, shap_vals.shape[1]))),
            dim=0,
        )
        matched_ec_x = torch.cat(
            (torch.tensor(matched_ec_x), torch.zeros((number_of_fake_showers, ec_x.shape[1]))),
            dim=0,
        )

    if pandora:
        e_pred_cali_pfo = torch.cat((matched_es_cali_pfo, fake_showers_e_pfo), dim=0)
        positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0)
        pandora_pid_all = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0)
        ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0)
        d = {
            "true_showers_E": energy_t.detach().cpu(),
            "reco_showers_E": e_reco.detach().cpu(),
            "pred_showers_E": e_pred.detach().cpu(),
            "e_pred_and_truth": e_pred_t.detach().cpu(),
            "pandora_calibrated_E": e_pred_cali.detach().cpu(),
            "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
            "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
            "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
            "pid": pid_t.detach().cpu(),
            "pandora_pid": pandora_pid_all.detach().cpu(),
            "step": torch.ones_like(energy_t.detach().cpu()) * step,
            "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
            "is_track_in_cluster": is_track.detach().cpu(),
            "vertex": vertex.detach().cpu().tolist(),
        }
    else:
        calibration_factor = torch.cat(
            (calibration_per_shower, fake_showers_e_cali_factor), dim=0
        )
        d = {
            "true_showers_E": energy_t.detach().cpu(),
            "reco_showers_E": e_reco.detach().cpu(),
            "pred_showers_E": e_pred.detach().cpu(),
            "e_pred_and_truth": e_pred_t.detach().cpu(),
            "pid": pid_t.detach().cpu(),
            "calibration_factor": calibration_factor.detach().cpu(),
            "calibrated_E": e_pred_cali.detach().cpu(),
            "step": torch.ones_like(energy_t.detach().cpu()) * step,
            "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
            "is_track_in_cluster": is_track.detach().cpu(),
            "vertex": vertex.detach().cpu().tolist(),
        }
    if pred_pos is not None:
        # Lists, not 2-D tensors, so each DataFrame cell holds one vector.
        d["pred_pos_matched"] = e_pred_pos.detach().cpu().tolist()
        d["pred_pid_matched"] = e_pred_pid.detach().cpu().tolist()
        d["pred_ref_pt_matched"] = e_pred_ref_pt.detach().cpu().tolist()
    if shap:
        d["shap_values"] = matched_shap_vals.tolist()
        d["ec_x"] = matched_ec_x.tolist()
    d["true_pos"] = pos_t.detach().cpu().tolist()
    df = pd.DataFrame(data=d)

    if save_plots_to_folder:
        for evt in np.unique(df.number_batch):
            evt_rows = df[df.number_batch == evt]
            if len(evt_rows):
                rndstr = generate_random_string(5)
                plot_event(
                    evt_rows,
                    pandora,
                    save_plots_to_folder + str(evt) + rndstr,
                    graph=dic["graph"].to("cpu"),
                    y=dic["part_true"],
                    labels=dic["graph"].ndata["particle_number"].long(),
                    is_track_in_cluster=df.is_track_in_cluster,
                )

    if number_of_showers_total is None:
        return df
    return df, number_of_showers_total, number_of_fake_showers_total
def get_correction_per_shower(labels, dic):
    """Pick one energy-correction factor per cluster label.

    For every unique label (in ascending order) the ``"correction"`` value of
    the hit with the largest ``"beta"`` in that cluster is taken. If the
    smallest label is not 0, a leading zero correction is prepended. That way
    position k lines up with label k when labels are contiguous from 0
    (assumed by callers; not checked here).
    """
    node_data = dic["graph"].ndata
    unique_labels = torch.unique(labels)
    pieces = []
    if len(unique_labels) > 0 and unique_labels[0] != 0:
        pieces.append(node_data["correction"][0].view(-1) * 0)
    for label in unique_labels:
        in_cluster = labels == label
        best_hit = torch.argmax(node_data["beta"][in_cluster])
        pieces.append(node_data["correction"][in_cluster][best_hit].view(-1))
    return torch.cat(pieces, dim=0)
def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Tabulate hits shared between predicted clusters and true particles.

    Returns ``(intersection_matrix, intersection_matrix_w)``. Both are float
    tensors of shape ``(len(shower_p_unique), len(particle_ids))`` on the
    device of ``shower_p_unique``. For column ``j``, the rows hold, per cluster
    label (via ``scatter_add`` over ``labels``):
    - in the first matrix, the number of hits whose ``particle_number`` equals
      ``particle_ids[j]``;
    - in the second, the summed ``e_hits`` of those same hits.
    """
    matrix_shape = (len(shower_p_unique), len(particle_ids))
    target_device = shower_p_unique.device
    hit_counts = torch.zeros(matrix_shape).to(target_device)
    hit_energy = torch.zeros(matrix_shape).to(target_device)
    particle_number = dic["graph"].ndata["particle_number"]
    for column, pid in enumerate(particle_ids):
        in_particle = particle_number == pid
        indicator = in_particle.to(labels.dtype)
        energy_in_particle = torch.where(in_particle, e_hits, torch.zeros_like(e_hits))
        hit_counts[:, column] = scatter_add(indicator, labels)
        hit_energy[:, column] = scatter_add(
            energy_in_particle, labels.to(energy_in_particle.device)
        )
    return hit_counts, hit_energy
def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Count |pred ∪ true| hits for every (predicted cluster, true particle) pair.

    Entry ``[s, p]`` is the number of hits that either carry cluster label
    ``shower_p_unique[s]`` or belong to particle ``particle_ids[p]``
    (``dic["graph"].ndata["particle_number"]``).

    Vectorized with one-hot indicators and |A ∪ B| = |A| + |B| - |A ∩ B|.
    This replaces the former per-pair Python loop, which launched one
    full-length reduction for each of the S*P pairs. Counts are exact while
    below 2**24 hits.

    Returns:
        float32 CPU tensor of shape ``(len(shower_p_unique), len(particle_ids))``.
    """
    device = labels.device
    particle_number = dic["graph"].ndata["particle_number"].to(device)
    # (n_hits, S) and (n_hits, P) membership indicators.
    pred_onehot = (labels.view(-1, 1) == shower_p_unique.to(device).view(1, -1)).float()
    true_onehot = (particle_number.view(-1, 1) == particle_ids.to(device).view(1, -1)).float()
    intersection = pred_onehot.t() @ true_onehot
    union = (
        pred_onehot.sum(dim=0).view(-1, 1)
        + true_onehot.sum(dim=0).view(1, -1)
        - intersection
    )
    return union.to(torch.float32).cpu()
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.7, td=0.03):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.

    Algorithm:
    - Unassigned hits with ``beta > tbeta`` act as condensation points and are
      visited in order of decreasing beta.
    - Each point claims every still-unassigned hit closer than ``td`` in
      ``X``, plus itself. Claimed hits get the point's hit index as their
      cluster id; earlier assignments are never overwritten.
    - Hits never claimed keep -1.

    Always claiming the condensation point itself guarantees progress. The
    previous version looped forever when ``d < td`` was False for the point
    (``td <= 0`` or NaN coordinates).

    Returns:
        long tensor of shape ``(n_points,)``.
    """
    n_points = betas.size(0)
    unassigned = torch.arange(n_points).to(betas.device)
    clustering = -1 * torch.ones(n_points, dtype=torch.long).to(betas.device)
    indices_condpoints = find_condpoints(betas, unassigned, tbeta)
    while len(indices_condpoints) > 0 and len(unassigned) > 0:
        index_condpoint = indices_condpoints[0]
        d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1)
        claimed = (d < td) | (unassigned == index_condpoint[0])
        clustering[unassigned[claimed]] = index_condpoint[0]
        unassigned = unassigned[~claimed]
        # Re-select condensation points among the hits that are still free.
        indices_condpoints = find_condpoints(betas, unassigned, tbeta)
    return clustering


def find_condpoints(betas, unassigned, tbeta):
    """Return indices of unassigned hits with ``beta > tbeta``, highest beta first.

    ``unassigned`` is an index tensor of free hits. The result has shape
    ``(K, 1)``, as produced by ``Tensor.nonzero()``.
    """
    is_unassigned = torch.zeros(betas.size(0), dtype=torch.bool, device=betas.device)
    is_unassigned[unassigned] = True
    candidates = is_unassigned & (betas > tbeta)
    indices_condpoints = candidates.nonzero()
    return indices_condpoints[(-betas[candidates]).argsort()]
def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    """Read the matched (particle, cluster) entries of an intersection matrix.

    The first row of ``intersection_matrix_w`` is always discarded. The first
    column is discarded too when the event contains particle id 0, and in
    that case ``row_ind`` is shifted down by one. The remaining matrix is
    transposed to (particle, cluster) order.

    Returns:
        1-D tensor of ``[row_ind[k], col_ind[k]]`` entries for
        ``k < len(col_ind)``, or the integer ``0`` when ``col_ind`` is empty.
    """
    ids_in_event = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(ids_in_event == 0) > 0:
        trimmed = intersection_matrix_w[1:, 1:]
        row_ind = row_ind - 1
    else:
        trimmed = intersection_matrix_w[1:, :]
    by_particle = torch.transpose(trimmed, 1, 0)
    n_pairs = len(col_ind)
    if n_pairs == 0:
        return 0
    return by_particle[row_ind[:n_pairs], col_ind[:n_pairs]]
def plot_iou_matrix(iou_matrix, image_path, hdbscan=False):
    """Render an IoU matrix as an annotated heat map, save it and log it to W&B.

    The first row of ``iou_matrix`` is dropped and the remainder is transposed
    before plotting. Each cell is annotated with its value rounded to one
    decimal. The PNG is written to ``image_path`` and logged under
    ``"iou_matrix_hdbscan"`` when ``hdbscan`` is true, else ``"iou_matrix"``.

    The figure is closed after saving. Previously every call leaked one
    pyplot figure.
    """
    iou_matrix = torch.transpose(iou_matrix[1:, :], 1, 0).detach().cpu().numpy()
    fig, ax = plt.subplots()
    try:
        ax.matshow(iou_matrix, cmap=plt.cm.Blues)
        for (j, i), value in np.ndenumerate(iou_matrix):
            ax.text(i, j, str(np.round(value, 1)), va="center", ha="center")
        fig.savefig(image_path, bbox_inches="tight")
    finally:
        plt.close(fig)
    key = "iou_matrix_hdbscan" if hdbscan else "iou_matrix"
    wandb.log({key: wandb.Image(image_path)})
def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    tracks=False,
    hdbscan=False,
    iou_threshold=0.25,
):
    """Match predicted showers to MC particles by maximising the total IoU.

    The IoU matrix is built from ``obtain_intersection_matrix`` /
    ``obtain_union_matrix``. Its first index along dim 0 is dropped
    (presumably label 0, i.e. ``shower_p_unique[0]`` — TODO confirm against
    those helpers). When particle id 0 is present, index 0 along dim 1 is also
    dropped. The result is transposed so that rows are particles and columns
    are predicted showers. The Hungarian assignment then runs on that matrix.

    Args:
        labels: per-hit predicted cluster labels.
        dic: dict whose ``"graph"`` entry exposes ``ndata["e_hits"]`` plus
            whatever the intersection/union helpers read.
        particle_ids: MC particle ids of the event. Id 0, when present, is
            excluded from matching (treated as noise).
        model_output: only its ``device`` is used.
        local_rank, i, path_save, pandora, hdbscan: unused. They only fed a
            disabled IoU-matrix plot and are kept for interface compatibility.
        tracks: unused; kept for interface compatibility.
        iou_threshold: IoU values below this are zeroed before the
            assignment, so such pairs are never reported as matches.

    Returns:
        Tuple ``(shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix)``:
        - ``shower_p_unique``: sorted unique labels, with 0 prepended if no
          hit carries label 0.
        - ``row_ind``: indices of the matched particles, relative to the full
          particle axis including id 0.
        - ``col_ind``: indices of the predictions they are matched to.
        - ``i_m_w``: the weighted intersection matrix.
        - ``iou_matrix``: the untrimmed IoU matrix.
    """
    shower_p_unique = torch.unique(labels)
    if not (labels == 0).any():
        # Guarantee label 0 occupies index 0. Use the labels' own dtype so
        # the result has the same dtype whichever branch is taken.
        zero_label = torch.zeros(
            1, dtype=shower_p_unique.dtype, device=shower_p_unique.device
        )
        shower_p_unique = torch.cat((zero_label, shower_p_unique.view(-1)), dim=0)

    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    iou_matrix = i_m / u_m

    has_noise_particle = bool((particle_ids == 0).any())
    # When present, also remove the MC particle corresponding to noise.
    trimmed = iou_matrix[1:, 1:] if has_noise_particle else iou_matrix[1:, :]
    iou_matrix_num = torch.transpose(trimmed, 1, 0).clone().detach().cpu().numpy()
    iou_matrix_num[iou_matrix_num < iou_threshold] = 0
    row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # linear_sum_assignment pairs as many rows/columns as it can, even at
    # zero IoU; keep only pairs that survived the threshold.
    matched = iou_matrix_num[row_ind, col_ind] > 0
    row_ind = row_ind[matched]
    col_ind = col_ind[matched]
    if has_noise_particle:
        # Re-express particle indices relative to the untrimmed axis.
        row_ind = row_ind + 1
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix
def clustering_obtain_labels(X, betas, device):
    """Cluster points with ``get_clustering`` and relabel them densely.

    Each raw cluster id is replaced by its rank among the sorted distinct ids.
    - If -1 is among the raw ids (it is then the smallest), it becomes
      label 0.
    - Otherwise every label is shifted up by one, so label 0 stays unused
      (other label helpers in this module reserve 0 for noise).

    Args:
        X: per-point coordinates, forwarded to ``get_clustering``.
        betas: per-point beta values, forwarded to ``get_clustering``.
        device: device of the returned tensor.

    Returns:
        1-D long tensor of dense cluster labels on ``device``.
    """
    clustering = get_clustering(betas, X).detach().reshape(-1)
    if clustering.numel() == 0:
        # torch.unique(...)[0] below would raise on an empty clustering.
        return torch.zeros(0, dtype=torch.long, device=device)
    # With sorted=True the inverse indices are exactly each id's rank among
    # the sorted distinct ids. This replaces a per-element Python
    # list.index() loop. It also avoids torch.Tensor(<LongTensor>), which
    # older PyTorch versions reject.
    distinct_ids, dense_labels = torch.unique(
        clustering, sorted=True, return_inverse=True
    )
    dense_labels = dense_labels.long()
    if distinct_ids[0] != -1:
        dense_labels = dense_labels + 1
    return dense_labels.to(device)
def hfdb_obtain_labels(X, device, eps=0.1):
    """Cluster ``X`` with HDBSCAN and return the labels shifted up by one.

    Settings are ``min_cluster_size=8``, ``min_samples=8`` and
    ``cluster_selection_epsilon=eps``. scikit-learn's HDBSCAN labels noise
    -1, so after the shift noise is 0 and clusters start at 1.
    NOTE(review): per sklearn's docs, non-finite or missing inputs get -2 or
    -3, which would end up as -1 or -2 here.

    Args:
        X: point coordinates (tensor); fitted on CPU.
        device: device of the returned tensor.
        eps: ``cluster_selection_epsilon`` passed to HDBSCAN.

    Returns:
        1-D long tensor of labels on ``device``.
    """
    points = X.detach().cpu()
    clusterer = HDBSCAN(min_cluster_size=8, min_samples=8, cluster_selection_epsilon=eps)
    clusterer.fit(points)
    shifted = np.reshape(clusterer.labels_ + 1, (-1))
    return torch.Tensor(shifted).long().to(device)
def dbscan_obtain_labels(X, device):
    """Cluster ``X`` with DBSCAN using a radius derived from the data.

    ``eps`` is one thirtieth of the smallest per-dimension extent
    (``|min - max|`` along dim 0) of ``X``, and ``min_samples`` is 15.
    DBSCAN labels noise -1, so labels are shifted up by one to make noise 0.
    NOTE(review): if ``X`` has zero spread along some dimension, ``eps`` is 0,
    which scikit-learn's DBSCAN likely rejects — confirm whether this can
    occur.

    Args:
        X: 2-D tensor of point coordinates; fitted on CPU.
        device: device of the returned tensor.

    Returns:
        1-D long tensor of labels on ``device``.
    """
    lower = torch.min(X, dim=0)[0]
    upper = torch.max(X, dim=0)[0]
    smallest_extent = torch.min(torch.abs(lower - upper))
    radius = (smallest_extent / 30).view(-1).detach().cpu().numpy()[0]
    fitted = DBSCAN(eps=radius, min_samples=15).fit(X.detach().cpu())
    shifted = np.reshape(fitted.labels_ + 1, (-1))
    return torch.Tensor(shifted).long().to(device)
class CachedIndexList:
    """Wrap a list and memoise the results of ``list.index`` lookups.

    Each distinct value is searched for in the list at most once; later
    lookups of the same value are dict hits. The cache is never invalidated,
    so changes to ``lst`` after a lookup are not reflected for that value.
    """

    def __init__(self, lst):
        self.lst = lst  # wrapped list, searched with list.index on a cache miss
        self.cache = {}  # value -> index of its first occurrence in lst

    def index(self, value):
        """Return the index of the first occurrence of ``value`` in ``lst``.

        Raises:
            ValueError: if ``value`` is not in ``lst`` (nothing is cached).
            TypeError: if ``value`` is unhashable.
        """
        try:
            return self.cache[value]
        except KeyError:
            pass
        position = self.lst.index(value)
        self.cache[value] = position
        return position
def get_labels_pandora(tracks, dic, device):
    """Return Pandora's per-hit assignment relabelled as dense ids 0..K-1.

    Each id is replaced by its rank among the sorted distinct ids.

    Args:
        tracks: if True read ``ndata["pandora_pfo"]``, otherwise
            ``ndata["pandora_cluster"]``, from ``dic["graph"]``.
        dic: dict whose ``"graph"`` entry exposes ``ndata``.
        device: device of the returned tensor.

    Returns:
        Long tensor of dense labels on ``device``, with the shape of the
        selected field (expected 1-D).
    """
    field = "pandora_pfo" if tracks else "pandora_cluster"
    raw_labels = dic["graph"].ndata[field].long()
    # With sorted=True the inverse indices are the ranks among the sorted
    # distinct ids. This is the same mapping the previous per-hit Python
    # loop produced, but vectorised and kept on the tensor's device. Ranks
    # do not change under a constant shift, so the former "+1" was not
    # needed: the smallest id (e.g. -1, if used for unassigned hits) maps
    # to 0.
    # NOTE(review): when no -1 is present, label 0 goes to the smallest real
    # id. clustering_obtain_labels instead shifts by one to keep 0 free —
    # confirm this difference is intended.
    _, dense_labels = torch.unique(raw_labels, sorted=True, return_inverse=True)
    return dense_labels.long().to(device)