Spaces:
Sleeping
Sleeping
| from lightning.pytorch.callbacks import BaseFinetuning | |
| import torch | |
| import dgl | |
| from src.layers.inference_oc import DPC_custom_CLD | |
| from src.layers.inference_oc import match_showers | |
| from src.layers.inference_oc import remove_bad_tracks_from_cluster | |
class FreezeClustering(BaseFinetuning):
    """Lightning fine-tuning callback that freezes the clustering stage.

    Before training starts, the batch-norm, GATr backbone, clustering head
    and beta head of the ``pl_module`` are frozen; no gradual unfreezing is
    performed afterwards.
    """

    def __init__(self):
        super().__init__()

    def freeze_before_training(self, pl_module):
        # Freeze every submodule that belongs to the clustering stage.
        frozen_parts = (
            pl_module.ScaledGooeyBatchNorm2_1,
            pl_module.gatr,
            pl_module.clustering,
            pl_module.beta,
        )
        for part in frozen_parts:
            self.freeze(part)
        print("CLUSTERING HAS BEEN FROOOZEN")

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # Intentionally a no-op: the frozen modules stay frozen for the
        # whole run.
        print("Not finetunning")
def obtain_batch_numbers(x, g):
    """Build a per-node batch-index tensor for a batched DGL graph.

    Parameters
    ----------
    x : torch.Tensor
        Any tensor living on the target device; only ``x.device`` is used.
    g : dgl.DGLGraph
        A batched graph; it is unbatched to recover the per-graph node counts.

    Returns
    -------
    torch.Tensor
        1-D float tensor of length ``g.number_of_nodes()`` whose value at each
        node is the index of the sub-graph that node belongs to.
    """
    dev = x.device
    # One constant-valued chunk per sub-graph, concatenated in graph order.
    batch_numbers = [
        index * torch.ones(subgraph.number_of_nodes(), device=dev)
        for index, subgraph in enumerate(dgl.unbatch(g))
    ]
    return torch.cat(batch_numbers, dim=0)
def obtain_clustering_for_matched_showers(
    batch_g, model_output, y_all, local_rank, use_gt_clusters=False, add_fakes=True
):
    """Cluster hits into showers, match them to truth particles, and collect
    per-shower sub-graphs plus energy bookkeeping.

    Parameters
    ----------
    batch_g : dgl.DGLGraph
        Batched input graph; node data is read (``pos_hits_xyz``, ``h``,
        ``e_hits``, ``chi_squared_tracks``, optionally ``pos_pxpypz`` /
        ``pos_pxpypz_at_vertex``) and, when ``use_gt_clusters`` is False,
        ``coords``/``beta`` are written from ``model_output``.
    model_output : torch.Tensor
        Per-node model output; columns 0:3 are used as clustering coordinates
        and column 3 as the beta score (when not using GT clusters).
    y_all : object
        Truth container exposing ``batch_number``, ``copy()``, ``mask()``,
        ``E``, ``m``, ``pid``, ``coord`` — project type, schema assumed from
        usage here (TODO confirm against its definition).
    local_rank : int
        Forwarded to ``match_showers`` (presumably for device/plot selection).
    use_gt_clusters : bool
        If True, take cluster labels from ``particle_number`` ground truth
        instead of running DPC clustering.
    add_fakes : bool
        If True, also collect sub-graphs for unmatched ("fake") showers.

    Returns
    -------
    tuple
        (batched matched+fake sub-graphs, true energies, reco energies
        (matched then fakes), matched PIDs, true daughter-corrected energies,
        matched coordinates, number of fakes, fakes_idx).
        NOTE(review): ``fakes_idx`` is assigned inside the per-graph loop, so
        the returned value is that of the LAST graph only — and the function
        would raise NameError if ``graphs`` were empty. Verify intent.
    """
    graphs_showers_matched = []
    graphs_showers_fakes = []
    true_energy_showers = []
    reco_energy_showers = []
    reco_energy_showers_fakes = []
    energy_true_daughters = []
    y_pids_matched = []
    y_coords_matched = []
    if not use_gt_clusters:
        # Expose model predictions as node features for the clustering step.
        batch_g.ndata["coords"] = model_output[:, 0:3]
        batch_g.ndata["beta"] = model_output[:, 3]
    graphs = dgl.unbatch(batch_g)
    batch_id = y_all.batch_number
    for i in range(0, len(graphs)):
        # Restrict the truth container to particles of event/graph i.
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y = y_all.copy()
        y.mask(mask.flatten())
        dic["part_true"] = y
        if not use_gt_clusters:
            # NOTE(review): ``betas`` is computed but never used below.
            betas = torch.sigmoid(dic["graph"].ndata["beta"])
            X = dic["graph"].ndata["coords"]
        if use_gt_clusters:
            labels = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels = DPC_custom_CLD(X, dic["graph"], model_output.device)
            labels, _ = remove_bad_tracks_from_cluster(dic["graph"], labels)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        shower_p_unique = torch.unique(labels)
        # Hungarian-style matching of predicted showers to truth particles;
        # i_m_w (intersection matrix/weights, presumably) is unused here.
        shower_p_unique, row_ind, col_ind, i_m_w, _ = match_showers(
            labels, dic, particle_ids, model_output, local_rank, i, None
        )
        row_ind = torch.Tensor(row_ind).to(model_output.device).long()
        col_ind = torch.Tensor(col_ind).to(model_output.device).long()
        if torch.sum(particle_ids == 0) > 0:
            # Particle id 0 (noise/background) is present: shift truth indices
            # so they index into the masked truth arrays correctly.
            row_ind_ = row_ind - 1
        else:
            # if there is no zero then index 0 corresponds to particle 1.
            row_ind_ = row_ind
        # Cluster labels are 1-based (0 is reserved), hence the +1.
        index_matches = col_ind + 1
        index_matches = index_matches.to(model_output.device).long()
        for j, unique_showers_label in enumerate(index_matches):
            # Only keep showers matched exactly once (skip duplicates).
            if torch.sum(unique_showers_label == index_matches) == 1:
                # NOTE(review): ``index_in_matched`` is computed but unused.
                index_in_matched = torch.argmax(
                    (unique_showers_label == index_matches) * 1
                )
                mask = labels == unique_showers_label
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                # Build an edgeless sub-graph holding only this shower's hits.
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)
                g.ndata["h"] = graphs[i].ndata["h"][mask]
                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                energy_t = dic["part_true"].E.to(model_output.device)
                # ``m`` used as daughter-corrected energy — TODO confirm.
                energy_t_corr_daughters = dic["part_true"].m.to(
                    model_output.device
                )
                true_energy_shower = energy_t[row_ind_[j]]
                y_pids_matched.append(y.pid[row_ind_[j]].item())
                y_coords_matched.append(y.coord[row_ind_[j]].detach().cpu().numpy())
                energy_true_daughters.append(energy_t_corr_daughters[row_ind_[j]])
                # Reconstructed energy = sum of hit energies in the shower.
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                graphs_showers_matched.append(g)
                true_energy_showers.append(true_energy_shower.view(-1))
                reco_energy_showers.append(reco_energy_shower.view(-1))
        # Mark matched showers (and the reserved label 0) with -1 so that the
        # remaining entries are the fakes.
        # NOTE(review): ``pred_showers`` aliases ``shower_p_unique`` — the
        # assignments below mutate it in place; confirm this is intended.
        pred_showers = shower_p_unique
        pred_showers[index_matches] = -1
        pred_showers[
            0
        ] = (
            -1
        )
        mask_fakes = pred_showers != -1
        fakes_idx = torch.where(mask_fakes)[0]
        if add_fakes:
            for j in fakes_idx:
                mask = labels == j
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                # Same edgeless sub-graph construction as for matched showers.
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)
                g.ndata["h"] = graphs[i].ndata["h"][mask]
                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                graphs_showers_fakes.append(g)
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                reco_energy_showers_fakes.append(reco_energy_shower.view(-1))
    # Fakes are appended after all matched showers in the batched output.
    graphs_showers_matched = dgl.batch(graphs_showers_matched + graphs_showers_fakes)
    true_energy_showers = torch.cat(true_energy_showers, dim=0)
    reco_energy_showers = torch.cat(reco_energy_showers + reco_energy_showers_fakes, dim=0)
    e_true_corr_daughters = torch.cat(energy_true_daughters, dim=0)
    number_of_fakes = len(reco_energy_showers_fakes)
    return (
        graphs_showers_matched,
        true_energy_showers,
        reco_energy_showers,
        y_pids_matched,
        e_true_corr_daughters,
        y_coords_matched,
        number_of_fakes,
        fakes_idx
    )