Spaces:
Sleeping
Sleeping
File size: 3,180 Bytes
cc0720f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import numpy as np
import torch
import dgl
from src.dataset.functions_data import (
calculate_distance_to_boundary,
)
import time
from src.dataset.functions_particles import concatenate_Particles_GT, Particles_GT
from src.dataset.dataclasses import Hits
def create_inputs_from_table(output, prediction=False, args=None):
    """Build the ground-truth particle table and the hit container for one event.

    Args:
        output: event table; must provide the ``"X_track"``, ``"X_hit"`` and
            ``"X_gen"`` entries, whose lengths size the hit and particle sets.
        prediction: forwarded to ``Particles_GT.fill``.
        args: run configuration forwarded to ``Hits.from_data`` and
            ``Particles_GT.fill``.

    Returns:
        list: a two-element list ``[y_data_graph, hits]``.
    """
    # Track entries plus hit entries give the total hit count for Hits.from_data.
    n_tracks = len(output["X_track"])
    n_calo_hits = len(output["X_hit"])
    total_hits = np.int32(n_tracks + n_calo_hits)
    total_particles = np.int32(len(output["X_gen"]))

    hits = Hits.from_data(output, total_hits, args, total_particles)

    particles_gt = Particles_GT()
    particles_gt.fill(output, prediction, args)
    return [particles_gt, hits]
def create_graph(output, for_training=True, args=None, *, min_hits=10):
    """Build a node-only DGL graph and its ground-truth particles for one event.

    Args:
        output: event table; forwarded unchanged to ``create_inputs_from_table``.
        for_training: when False the event is processed in prediction mode.
            In that mode the noise-only check is skipped. If ``args.pandora``
            is set, Pandora PFO features are also attached to the nodes.
        args: run configuration forwarded to ``create_inputs_from_table``.
            May be None, in which case Pandora features are never attached.
        min_hits: events with fewer than this many nodes are flagged as empty.
            Keyword-only; defaults to 10, the previously hard-coded threshold.

    Returns:
        tuple: ``([g, y_data_graph], graph_empty)``. ``g`` is a DGL graph with
        one node per hit and no edges. ``graph_empty`` is True when the event
        has fewer than ``min_hits`` nodes or, in training mode only, when every
        hit is linked to noise (-1). If ``create_inputs_from_table`` returns a
        single item, ``([0, 0], True)`` is returned instead.
    """
    prediction = not for_training
    result = create_inputs_from_table(output, prediction=prediction, args=args)
    if len(result) == 1:
        # Defensive guard kept for compatibility: create_inputs_from_table in
        # this module currently always returns two items.
        return [0, 0], True
    y_data_graph, hits = result

    g = _build_node_graph(hits)
    # Check args before dereferencing it. With the default args=None,
    # `args.pandora` used to raise AttributeError in prediction mode.
    if prediction and args is not None and args.pandora:
        _add_pandora_features(g, hits)

    graph_empty = _is_graph_empty(hits, prediction, min_hits)
    return [g, y_data_graph], graph_empty


def _build_node_graph(hits):
    """Create an edge-less DGL graph with one node per hit and attach node features."""
    g = dgl.graph(([], []))
    g.add_nodes(hits.pos_xyz_hits.shape[0])
    # Node input vector: concatenation of position, one-hot hit type, e_hits and p_hits.
    g.ndata["h"] = torch.cat(
        (hits.pos_xyz_hits, hits.hit_type_one_hot, hits.e_hits, hits.p_hits), dim=1
    ).float()
    g.ndata["p_hits"] = hits.p_hits.float()
    g.ndata["pos_hits_xyz"] = hits.pos_xyz_hits.float()
    g.ndata["pos_pxpypz_at_vertex"] = hits.pos_pxpypz.float()
    # NOTE(review): stored without .float(), unlike the copy above — confirm the
    # dtype is intended.
    g.ndata["pos_pxpypz"] = hits.pos_pxpypz  # TrackState::AtIP
    g.ndata["pos_pxpypz_at_calo"] = hits.pos_pxpypz_calo  # TrackState::AtCalorimeter
    # Keep this after the position features are set. It presumably reads them
    # (verify in functions_data).
    g = calculate_distance_to_boundary(g)
    g.ndata["hit_type"] = hits.hit_type_feature.float()
    g.ndata["e_hits"] = hits.e_hits.float()
    g.ndata["chi_squared_tracks"] = hits.chi_squared_tracks.float()
    # hit_particle_link uses -1 for noise. Shift by one so noise becomes 0 and
    # MC particle 0 becomes 1.
    g.ndata["particle_number"] = hits.hit_particle_link.float() + 1
    return g


def _add_pandora_features(g, hits):
    """Attach Pandora PFO link, energy, momentum, reference point and PID to the nodes."""
    pandora = hits.pandora_features
    g.ndata["pandora_pfo"] = pandora.pandora_pfo_link.float()
    g.ndata["pandora_pfo_energy"] = pandora.pfo_energy.float()
    g.ndata["pandora_momentum"] = pandora.pandora_mom_components.float()
    g.ndata["pandora_reference_point"] = pandora.pandora_ref_point.float()
    g.ndata["pandora_pid"] = pandora.pandora_pid.float()


def _is_graph_empty(hits, prediction, min_hits):
    """Return True if the event has too few hits or (training only) only noise hits."""
    if hits.pos_xyz_hits.shape[0] < min_hits:
        return True
    if prediction:
        return False
    unique_links = torch.unique(hits.hit_particle_link)
    # A single unique link equal to -1 means every hit is noise.
    return bool(unique_links.shape[0] == 1 and unique_links[0] == -1)
def graph_batch_func(list_graphs):
    """Collate per-event samples into one batch for the graph DataLoader.

    Args:
        list_graphs (list): items yielded by the iterable dataset. Each item is
            indexed so that ``item[0]`` is a DGL graph.

    Returns:
        tuple: ``(bg, ys)``. ``bg`` is the ``dgl.batch`` of all graphs. ``ys``
        is what ``concatenate_Particles_GT`` returns for the full item list.
    """
    graphs = [sample[0] for sample in list_graphs]
    # NOTE(review): the whole item list (not just item[1]) is passed here.
    # Presumably concatenate_Particles_GT extracts the ground truth itself —
    # verify against its definition.
    ys = concatenate_Particles_GT(list_graphs)
    batched = dgl.batch(graphs)
    return batched, ys
|