File size: 3,180 Bytes
cc0720f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import torch
import dgl
from src.dataset.functions_data import (
    calculate_distance_to_boundary,
)
import time
from src.dataset.functions_particles import concatenate_Particles_GT, Particles_GT

from src.dataset.dataclasses import Hits

def create_inputs_from_table(output, prediction=False, args=None):
    """Build the ground-truth particle container and the hit container.

    Args:
        output: Mapping indexed with the keys "X_track", "X_hit" and
            "X_gen" (only their lengths are read here). It is also passed
            unchanged to ``Hits.from_data`` and ``Particles_GT.fill``.
        prediction (bool): Forwarded to ``Particles_GT.fill``.
        args: Forwarded to ``Hits.from_data`` and ``Particles_GT.fill``.

    Returns:
        list: ``[particles, hits]`` where ``particles`` is the filled
        ``Particles_GT`` instance and ``hits`` is the object returned by
        ``Hits.from_data``.
    """
    # The hit count is the sum of the "X_track" and "X_hit" lengths.
    n_track_entries = len(output["X_track"])
    n_hit_entries = len(output["X_hit"])
    total_hits = np.int32(n_track_entries + n_hit_entries)
    total_particles = np.int32(len(output["X_gen"]))

    hits = Hits.from_data(output, total_hits, args, total_particles)

    particles = Particles_GT()
    particles.fill(output, prediction, args)

    return [particles, hits]
 



def create_graph(output, for_training=True, args=None):
    """Assemble the DGL graph of hits for one entry of ``output``.

    Args:
        output: Passed unchanged to ``create_inputs_from_table``.
        for_training (bool): When falsy the graph is built in prediction
            mode; pandora node features are then attached if
            ``args.pandora`` is truthy.
        args: Forwarded to ``create_inputs_from_table``. ``args.pandora`` is
            read only in prediction mode.

    Returns:
        tuple: ``([graph, particles], graph_empty)``. ``graph_empty`` is
        ``True`` when the graph has fewer than 10 nodes, or, in training
        mode, when the only hit-to-particle link value present is -1.
        ``([0, 0], True)`` is returned if ``create_inputs_from_table``
        yields a single-element result.
    """
    prediction = not for_training

    inputs = create_inputs_from_table(output, prediction=prediction, args=args)

    # NOTE(review): create_inputs_from_table, as written in this file, always
    # returns two items, so this guard looks unreachable -- confirm before
    # removing it.
    if len(inputs) == 1:
        return [0, 0], True

    particles, hits = inputs

    # Start from an edge-less graph with one node per hit.
    graph = dgl.graph(([], []))
    n_nodes = hits.pos_xyz_hits.shape[0]
    graph.add_nodes(n_nodes)

    # "h" concatenates position, one-hot hit type, e_hits and p_hits
    # along dim 1.
    node_features = torch.cat(
        (hits.pos_xyz_hits, hits.hit_type_one_hot, hits.e_hits, hits.p_hits),
        dim=1,
    )
    graph.ndata["h"] = node_features.float()
    graph.ndata["p_hits"] = hits.p_hits.float()
    graph.ndata["pos_hits_xyz"] = hits.pos_xyz_hits.float()
    # "pos_pxpypz_at_vertex" and "pos_pxpypz" come from the same tensor; only
    # the former is cast to float.
    graph.ndata["pos_pxpypz_at_vertex"] = hits.pos_pxpypz.float()
    graph.ndata["pos_pxpypz"] = hits.pos_pxpypz  # TrackState::AtIP
    graph.ndata["pos_pxpypz_at_calo"] = hits.pos_pxpypz_calo  # TrackState::AtCalorimeter

    # The helper returns the graph that is used from here on; the remaining
    # node fields are attached to that returned graph.
    graph = calculate_distance_to_boundary(graph)

    graph.ndata["hit_type"] = hits.hit_type_feature.float()
    graph.ndata["e_hits"] = hits.e_hits.float()
    graph.ndata["chi_squared_tracks"] = hits.chi_squared_tracks.float()
    # Shift links by one (noise idx is 0 and particle MC 0 starts at 1).
    graph.ndata["particle_number"] = hits.hit_particle_link.float() + 1

    if prediction and args.pandora:
        pandora = hits.pandora_features
        graph.ndata["pandora_pfo"] = pandora.pandora_pfo_link.float()
        graph.ndata["pandora_pfo_energy"] = pandora.pfo_energy.float()
        graph.ndata["pandora_momentum"] = pandora.pandora_mom_components.float()
        graph.ndata["pandora_reference_point"] = pandora.pandora_ref_point.float()
        graph.ndata["pandora_pid"] = pandora.pandora_pid.float()

    # In training mode a graph whose only link value is -1 is flagged empty;
    # in either mode a graph with fewer than 10 nodes is flagged empty.
    link_values = torch.unique(hits.hit_particle_link)
    only_unlinked = (
        not prediction
        and link_values.shape[0] == 1
        and bool(link_values[0] == -1)
    )
    graph_empty = only_unlinked or n_nodes < 10

    return [graph, particles], graph_empty





def graph_batch_func(list_graphs):
    """Collate dataset samples into one batched graph plus merged targets.

    Args:
        list_graphs (list): Samples from the iterable dataset. Element 0 of
            each sample is the DGL graph; the full list is handed to
            ``concatenate_Particles_GT``. Presumably each sample is the
            ``[g, y_data_graph]`` pair built by ``create_graph`` -- verify
            against the dataset.

    Returns:
        tuple: ``(batched_graph, particles)`` where ``batched_graph`` is the
        result of ``dgl.batch`` over the per-sample graphs and ``particles``
        is whatever ``concatenate_Particles_GT`` returns.
    """
    graphs = [sample[0] for sample in list_graphs]
    particles = concatenate_Particles_GT(list_graphs)
    return dgl.batch(graphs), particles