Spaces:
Sleeping
Sleeping
import dgl
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch_scatter import scatter_add, scatter_sum
from torch_scatter import scatter_sum

from src.dataset.functions_data import (
    get_ratios,
    find_mask_no_energy,
    find_cluster_id,
    get_particle_features,
    get_hit_features,
    calculate_distance_to_boundary,
    concatenate_Particles_GT,
)
def create_inputs_from_table(
    output, hits_only, prediction=False, hit_chis=False, pos_pxpy=False, is_Ks=False
):
    """Used by graph creation to get nodes and edge features

    Args:
        output (_type_): input from the root reading
        hits_only (_type_): reading only hits or also tracks
        prediction (bool, optional): if running in eval mode. Defaults to False.
        hit_chis (bool, optional): if True, append the masked track chi-squared
            tensor to the result; otherwise a None placeholder is appended.
            Defaults to False.
        pos_pxpy (bool, optional): forwarded to get_hit_features. Defaults to False.
        is_Ks (bool, optional): Ks-dataset mode; selects which pandora tensors get
            hit-masked and whether daughters are included. Defaults to False.

    Returns:
        _type_: all information to construct a graph
    """
    # Event sizes come from summing the padding masks produced by the root reading.
    number_hits = np.int32(np.sum(output["pf_mask"][0]))
    number_part = np.int32(np.sum(output["pf_mask"][1]))
    # Per-hit tensors plus bookkeeping (particle links, pandora outputs, etc.).
    (
        pos_xyz_hits,
        pos_pxpypz,
        p_hits,
        e_hits,
        hit_particle_link,
        pandora_cluster,
        pandora_cluster_energy,
        pfo_energy,
        pandora_mom,
        pandora_ref_point,
        unique_list_particles,
        cluster_id,
        hit_type_feature,
        pandora_pfo_link,
        daughters,
        hit_link_modified,
        connection_list,
        chi_squared_tracks,
    ) = get_hit_features(
        output,
        number_hits,
        prediction,
        number_part,
        hit_chis=hit_chis,
        pos_pxpy=pos_pxpy,
        is_Ks=is_Ks,
    )
    # features particles
    y_data_graph = get_particle_features(
        unique_list_particles, output, prediction, connection_list
    )
    assert len(y_data_graph) == len(unique_list_particles)
    # remove particles that have no energy, no hits or only track hits
    mask_hits, mask_particles = find_mask_no_energy(
        cluster_id,
        hit_type_feature,
        e_hits,
        y_data_graph,
        daughters,
        prediction,
        is_Ks=is_Ks,
    )
    # create mapping from links to number of particles in the event
    # (cluster ids are recomputed on the surviving hits only)
    cluster_id, unique_list_particles = find_cluster_id(hit_particle_link[~mask_hits])
    y_data_graph.mask(~mask_particles)
    if prediction:
        if is_Ks:
            # Eval mode, Ks dataset: pandora momentum/reference-point tensors are
            # per-hit here and therefore masked like the other hit tensors.
            result = [
                y_data_graph,  # y_data_graph[~mask_particles],
                p_hits[~mask_hits],
                e_hits[~mask_hits],
                cluster_id,
                hit_particle_link[~mask_hits],
                pos_xyz_hits[~mask_hits],
                pos_pxpypz[~mask_hits],
                pandora_cluster[~mask_hits],
                pandora_cluster_energy[~mask_hits],
                pandora_mom[~mask_hits],
                pandora_ref_point[~mask_hits],
                pfo_energy[~mask_hits],
                pandora_pfo_link[~mask_hits],
                hit_type_feature[~mask_hits],
                hit_link_modified[~mask_hits],
                daughters[~mask_hits]
            ]
        else:
            # NOTE(review): this branch omits the daughters entry, so the final
            # result is one element shorter than in the other two branches — the
            # 19-way unpack in create_graph would fail here. Confirm this path
            # (prediction=True, is_Ks=False) is actually exercised.
            result = [
                y_data_graph,  # y_data_graph[~mask_particles],
                p_hits[~mask_hits],
                e_hits[~mask_hits],
                cluster_id,
                hit_particle_link[~mask_hits],
                pos_xyz_hits[~mask_hits],
                pos_pxpypz[~mask_hits],
                pandora_cluster[~mask_hits],
                pandora_cluster_energy[~mask_hits],
                pandora_mom,
                pandora_ref_point,
                pfo_energy[~mask_hits],
                pandora_pfo_link[~mask_hits],
                hit_type_feature[~mask_hits],
                hit_link_modified[~mask_hits],
            ]
    else:
        # Training mode: pandora outputs are passed through unmasked.
        result = [
            y_data_graph,  # y_data_graph[~mask_particles],
            p_hits[~mask_hits],
            e_hits[~mask_hits],
            cluster_id,
            hit_particle_link[~mask_hits],
            pos_xyz_hits[~mask_hits],
            pos_pxpypz[~mask_hits],
            pandora_cluster,
            pandora_cluster_energy,
            pandora_mom,
            pandora_ref_point,
            pfo_energy,
            pandora_pfo_link,
            hit_type_feature[~mask_hits],
            hit_link_modified[~mask_hits],
            daughters[~mask_hits]
        ]
    # Keep the result layout fixed: slot for chi-squared is always present.
    if hit_chis:
        result.append(
            chi_squared_tracks[~mask_hits],
        )
    else:
        result.append(None)
    hit_type = hit_type_feature[~mask_hits]
    # if hits only remove tracks, otherwise leave tracks
    if hits_only:
        # Drop hit types 0 and 1 (track hits); keep calorimeter hits only.
        hit_mask = (hit_type == 0) | (hit_type == 1)
        hit_mask = ~hit_mask
        for i in range(1, len(result)):
            if result[i] is not None:
                result[i] = result[i][hit_mask]
        # Remaining types start at 2, hence the shift before one-hot encoding.
        hit_type_one_hot = torch.nn.functional.one_hot(
            hit_type_feature[~mask_hits][hit_mask] - 2, num_classes=2
        )
    else:
        # if we want the tracks keep only 1 track hit per charged particle.
        # NOTE(review): hit_type == 10 looks like a sentinel that never matches the
        # 0-4 range one-hot-encoded below, so this mask may keep everything — confirm.
        hit_mask = hit_type == 10
        hit_mask = ~hit_mask
        for i in range(1, len(result)):
            if result[i] is not None:
                # if len(result[i].shape) == 2 and result[i].shape[0] == 3:
                #     result[i] = result[i][:, hit_mask]
                # else:
                #     result[i] = result[i][hit_mask]
                result[i] = result[i][hit_mask]
        hit_type_one_hot = torch.nn.functional.one_hot(
            hit_type_feature[~mask_hits][hit_mask], num_classes=5
        )
    result.append(hit_type_one_hot)
    result.append(connection_list)
    return result
def remove_hittype0(graph):
    """Return *graph* with every node whose ``hit_type`` equals 0 removed."""
    type0_nodes = torch.where(graph.ndata["hit_type"] == 0)[0]
    return dgl.remove_nodes(graph, type0_nodes)
def store_track_at_vertex_at_track_at_calo(graph):
    """Copy each track's at-vertex momentum onto its at-calo node, then drop the
    hit_type-0 (track-at-vertex) nodes so the graph is compatible with clustering.
    """
    at_calo = graph.ndata["hit_type"] == 1
    at_vertex = graph.ndata["hit_type"] == 0
    particle = graph.ndata["particle_number"].long()
    # Both node groups must line up particle-by-particle for the copy below.
    assert (particle[at_calo] == particle[at_vertex]).all()
    vertex_momentum = torch.zeros_like(graph.ndata["pos_pxpypz"])
    vertex_momentum[at_calo] = graph.ndata["pos_pxpypz"][at_vertex]
    graph.ndata["pos_pxpypz_at_vertex"] = vertex_momentum
    return remove_hittype0(graph)
def create_graph(
    output,
    config=None,
    n_noise=0,
):
    """Build one event's DGL graph and its ground-truth particle container.

    Args:
        output: event record from the root reading (same structure consumed by
            create_inputs_from_table).
        config: configuration object; ``config.graph_config`` entries control hit
            selection and eval-mode behaviour.
        n_noise (int, optional): not used in this function — TODO confirm callers.

    Returns:
        tuple: ``([g, y_data_graph], graph_empty)``; ``g`` and ``y_data_graph``
        are 0 when the event is rejected (``graph_empty`` is True).
    """
    # Third pf_mask channel flags the Ks dataset for this event.
    ks_dataset = np.float32(np.sum(output["pf_mask"][2]))
    hits_only = config.graph_config.get(
        "only_hits", False
    )  # Whether to only include hits in the graph
    # standardize_coords = config.graph_config.get("standardize_coords", False)
    extended_coords = config.graph_config.get("extended_coords", False)
    prediction = config.graph_config.get("prediction", False)
    hit_chis = config.graph_config.get("hit_chis_track", False)
    pos_pxpy = config.graph_config.get("pos_pxpy", False)
    is_Ks = (torch.sum(torch.Tensor([ks_dataset])))!=0 #config.graph_config.get("ks", False)
    # NOTE(review): this 19-element unpack matches the is_Ks / non-prediction
    # result layouts of create_inputs_from_table; the prediction-without-Ks
    # branch there returns fewer elements — confirm.
    (
        y_data_graph,
        p_hits,
        e_hits,
        cluster_id,
        hit_particle_link,
        pos_xyz_hits,
        pos_pxpypz,
        pandora_cluster,
        pandora_cluster_energy,
        pandora_mom,
        pandora_ref_point,
        pandora_pfo_energy,
        pandora_pfo_link,
        hit_type,
        hit_link_modified,
        daugthers,
        chi_squared_tracks,
        hit_type_one_hot,
        connections_list,
    ) = create_inputs_from_table(
        output,
        hits_only=hits_only,
        prediction=prediction,
        hit_chis=hit_chis,
        pos_pxpy=pos_pxpy,
        is_Ks=is_Ks,
    )
    graph_coordinates = pos_xyz_hits  # / 3330 # divide by detector size
    if pos_xyz_hits.shape[0] > 0:
        graph_empty = False
        # Start from an empty graph; only node data is attached (no edges here).
        g = dgl.graph(([], []))
        g.add_nodes(graph_coordinates.shape[0])
        # NOTE(review): both branches build the identical feature tensor; only the
        # dimension comments differ — confirm whether one branch should differ.
        if hits_only == False:
            hit_features_graph = torch.cat(
                (graph_coordinates, hit_type_one_hot, e_hits, p_hits), dim=1
            )  # dims = 8
        else:
            hit_features_graph = torch.cat(
                (graph_coordinates, hit_type_one_hot, e_hits, p_hits), dim=1
            )  # dims = 9
        g.ndata["h"] = hit_features_graph
        g.ndata["pos_hits_xyz"] = pos_xyz_hits
        g.ndata["pos_pxpypz"] = pos_pxpypz
        g = calculate_distance_to_boundary(g)
        g.ndata["hit_type"] = hit_type
        g.ndata[
            "e_hits"
        ] = e_hits  # if no tracks this is e and if there are tracks this fills the tracks e values with p
        if hit_chis:
            g.ndata["chi_squared_tracks"] = chi_squared_tracks
        g.ndata["particle_number"] = cluster_id
        g.ndata["hit_link_modified"] = hit_link_modified
        # NOTE(review): "daugthers" is misspelled but kept — downstream code reads
        # this exact key.
        g.ndata["daugthers"] = daugthers
        g.ndata["particle_number_nomap"] = hit_particle_link
        if prediction:
            # Attach pandora reconstruction outputs for evaluation/comparison.
            g.ndata["pandora_cluster"] = pandora_cluster
            g.ndata["pandora_pfo"] = pandora_pfo_link
            g.ndata["pandora_cluster_energy"] = pandora_cluster_energy
            g.ndata["pandora_pfo_energy"] = pandora_pfo_energy
            if is_Ks:
                g.ndata["pandora_momentum"] = pandora_mom
                g.ndata["pandora_reference_point"] = pandora_ref_point
        y_data_graph.calculate_corrected_E(g, connections_list)
        if ks_dataset>0: #is_Ks == True:
            # Ks events are kept only when exactly four photons (pid 22) are present.
            if y_data_graph.pid.flatten().shape[0] == 4 and np.count_nonzero(y_data_graph.pid.flatten() == 22) == 4:
                graph_empty = False
            else:
                graph_empty = True
        # Reject events with fewer than 10 hits, or with only track hits of which
        # fewer than 10 are tracks-at-calo (hit_type 1).
        if g.ndata["h"].shape[0] < 10 or (set(g.ndata["hit_type"].unique().tolist()) == set([0, 1]) and g.ndata["hit_type"][g.ndata["hit_type"] == 1].shape[0] < 10):
            graph_empty = True  # less than 10 hits
        if is_Ks == False:
            if len(y_data_graph) < 4:
                graph_empty = True
    else:
        # No hits at all: return placeholders.
        graph_empty = True
        g = 0
        y_data_graph = 0
    if pos_xyz_hits.shape[0] < 10:
        graph_empty = True
    print("graph_empty", graph_empty, pos_xyz_hits.shape[0])
    if graph_empty:
        return [g, y_data_graph], graph_empty
    # Fold the track-at-vertex nodes into their at-calo partners before returning.
    return [store_track_at_vertex_at_track_at_calo(g), y_data_graph], graph_empty
def graph_batch_func(list_graphs):
    """collator function for graph dataloader

    Args:
        list_graphs (list): list of graphs from the iterable dataset

    Returns:
        batch dgl: dgl batch of graphs
    """
    graphs = [entry[0] for entry in list_graphs]
    targets = concatenate_Particles_GT(list_graphs)
    batched_graph = dgl.batch(graphs)
    return batched_graph, targets