Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import dgl | |
| from src.dataset.functions_data import ( | |
| calculate_distance_to_boundary, | |
| ) | |
| import time | |
| from src.dataset.functions_particles import concatenate_Particles_GT, Particles_GT | |
| from src.dataset.dataclasses import Hits | |
def create_inputs_from_table(output, prediction=False, args=None):
    """Build the target particles and hit container for one event table.

    The hit count passed to ``Hits.from_data`` is the number of rows in
    ``output["X_track"]`` plus the number of rows in ``output["X_hit"]``; the
    particle count is the number of rows in ``output["X_gen"]`` (both as
    ``np.int32``).

    Args:
        output: event table indexable by the keys ``"X_track"``, ``"X_hit"``
            and ``"X_gen"``; it is forwarded unchanged to ``Hits.from_data``
            and ``Particles_GT.fill``.
        prediction: forwarded to ``Particles_GT.fill``.
        args: configuration object forwarded to both builders.

    Returns:
        A two-element list ``[y_data_graph, hits]`` where ``y_data_graph`` is
        the filled ``Particles_GT`` and ``hits`` is the ``Hits`` instance.
    """
    total_hits = np.int32(len(output["X_track"]) + len(output["X_hit"]))
    total_particles = np.int32(len(output["X_gen"]))

    hit_container = Hits.from_data(output, total_hits, args, total_particles)

    targets = Particles_GT()
    targets.fill(output, prediction, args)

    return [targets, hit_container]
def create_graph(output, for_training=True, args=None):
    """Build the DGL node graph and targets for one event.

    The graph starts edgeless (``dgl.graph(([], []))``) with one node per row
    of ``hits.pos_xyz_hits``; per-node features are then attached from the
    ``Hits`` container returned by ``create_inputs_from_table``.

    Args:
        output: raw event table forwarded to ``create_inputs_from_table``.
        for_training: when False the event is built in prediction mode,
            which also attaches Pandora reference features if
            ``args.pandora`` is truthy.
        args: configuration object forwarded to the input builders. May be
            ``None``; Pandora features are then skipped instead of raising
            ``AttributeError``.

    Returns:
        ``(sample, graph_empty)``. ``sample`` is ``[g, y_data_graph]`` (or
        ``[0, 0]`` if the builder yields a single item). ``graph_empty`` is
        True when the event has fewer than 10 hits or, in training mode,
        when every hit is linked to particle index -1 (noise).
    """
    prediction = not for_training
    result = create_inputs_from_table(output, prediction=prediction, args=args)
    if len(result) == 1:
        # Defensive path kept for compatibility with builders that signal
        # "no usable inputs" with a single-item result.
        return [0, 0], True
    y_data_graph, hits = result

    g = dgl.graph(([], []))
    g.add_nodes(hits.pos_xyz_hits.shape[0])
    g = _attach_hit_features(g, hits)
    # Guard against args=None (the default) before reading args.pandora.
    if prediction and args is not None and args.pandora:
        _attach_pandora_features(g, hits)

    return [g, y_data_graph], _is_graph_empty(hits, prediction)


def _attach_hit_features(g, hits):
    """Attach per-hit node features to ``g`` and return the resulting graph."""
    g.ndata["h"] = torch.cat(
        (hits.pos_xyz_hits, hits.hit_type_one_hot, hits.e_hits, hits.p_hits), dim=1
    ).float()
    g.ndata["p_hits"] = hits.p_hits.float()
    g.ndata["pos_hits_xyz"] = hits.pos_xyz_hits.float()
    g.ndata["pos_pxpypz_at_vertex"] = hits.pos_pxpypz.float()
    g.ndata["pos_pxpypz"] = hits.pos_pxpypz  # TrackState::AtIP (dtype left as-is)
    g.ndata["pos_pxpypz_at_calo"] = hits.pos_pxpypz_calo  # TrackState::AtCalorimeter
    # Use the graph returned by calculate_distance_to_boundary from here on.
    g = calculate_distance_to_boundary(g)
    g.ndata["hit_type"] = hits.hit_type_feature.float()
    g.ndata["e_hits"] = hits.e_hits.float()
    g.ndata["chi_squared_tracks"] = hits.chi_squared_tracks.float()
    # Shift by one: noise link -1 becomes 0 and MC particle 0 becomes 1.
    g.ndata["particle_number"] = hits.hit_particle_link.float() + 1
    return g


def _attach_pandora_features(g, hits):
    """Attach the Pandora PFO reference features of ``hits`` to ``g``."""
    g.ndata["pandora_pfo"] = hits.pandora_features.pandora_pfo_link.float()
    g.ndata["pandora_pfo_energy"] = hits.pandora_features.pfo_energy.float()
    g.ndata["pandora_momentum"] = hits.pandora_features.pandora_mom_components.float()
    g.ndata["pandora_reference_point"] = hits.pandora_features.pandora_ref_point.float()
    g.ndata["pandora_pid"] = hits.pandora_features.pandora_pid.float()


def _is_graph_empty(hits, prediction):
    """Return True if the event should be treated as empty and skipped."""
    if not prediction:
        unique_links = torch.unique(hits.hit_particle_link)
        # Every hit is linked to -1, i.e. the event contains only noise.
        if unique_links.shape[0] == 1 and unique_links[0] == -1:
            return True
    return hits.pos_xyz_hits.shape[0] < 10
def graph_batch_func(list_graphs):
    """Collate samples from the iterable graph dataset into one batch.

    Args:
        list_graphs (list): samples whose first item is a DGL graph
            (presumably the ``[g, y_data_graph]`` lists built by
            ``create_graph`` -- confirm against the dataset). The whole
            sample list is handed to ``concatenate_Particles_GT``.

    Returns:
        tuple: ``(bg, ys)`` where ``bg`` is ``dgl.batch`` of the graphs and
        ``ys`` is the result of ``concatenate_Particles_GT(list_graphs)``.
    """
    graphs = [sample[0] for sample in list_graphs]
    targets = concatenate_Particles_GT(list_graphs)
    batched = dgl.batch(graphs)
    return batched, targets