# HitPF_demo / src/dataset/dataclasses.py
# github-actions[bot]
# Sync from GitHub f6dbbfb
# cc0720f
from dataclasses import dataclass
from typing import Any, List, Optional
import torch
import numpy as np
@dataclass
class PandoraFeatures:
    """Pandora particle-flow-object (PFO) attributes attached to hit/track rows.

    Every field defaults to ``None``. ``Hits.from_data`` fills the PFO link and
    the per-row attributes gathered from ``X_pandora``; rows whose link is -1
    are zero-filled there. ``pandora_cluster`` and ``pandora_cluster_energy``
    are not set by ``Hits.from_data``.
    """

    # Not populated by Hits.from_data; contents are up to the caller.
    pandora_cluster: Optional[Any] = None
    pandora_cluster_energy: Optional[Any] = None

    # Per-row values gathered from X_pandora (column indices as used by Hits.from_data).
    pfo_energy: Optional[torch.Tensor] = None  # X_pandora column 7
    pandora_mom: Optional[torch.Tensor] = None  # X_pandora column 8
    pandora_ref_point: Optional[torch.Tensor] = None  # X_pandora columns 4:7
    pandora_pid: Optional[torch.Tensor] = None  # X_pandora column 0

    # Row index into X_pandora for each hit/track; -1 is treated as "not linked".
    pandora_pfo_link: Optional[torch.Tensor] = None
    pandora_mom_components: Optional[torch.Tensor] = None  # X_pandora columns 1:4
@dataclass
class Hits:
    """Per-row input features of one event: calorimeter hits first, then tracks.

    Every tensor field stacks the rows built from ``output["X_hit"]`` followed,
    when the event has tracks, by the rows built from ``output["X_track"]``.
    A quantity carried by only one kind of row is zero for the other kind.
    Build instances with :meth:`from_data`.
    """

    pos_xyz_hits: Any  # (N, 3): X_hit[:, 6:9], then X_track[:, 12:15] (referencePoint_calo)
    pos_pxpypz: Any  # (N, 3): zeros for hits, then X_track[:, 6:9]
    pos_pxpypz_calo: Any  # (N, 3): zeros for hits, then X_track[:, 22:25]
    p_hits: Any  # (N, 1): zeros for hits, then X_track[:, 5]
    e_hits: Any  # (N, 1): X_hit[:, 5], then zeros for tracks
    hit_particle_link: Any  # output["ygen_hit"], then output["ygen_track"]
    pandora_features: Any  # type PandoraFeatures; None unless args.pandora is truthy
    hit_type_feature: Any  # int64: X_hit[:, 10] + 1 for hits, X_track[:, 0] for tracks
    chi_squared_tracks: Any  # zeros for hits, then X_track[:, 15] / X_track[:, 16]
    hit_type_one_hot: Any  # one_hot(hit_type_feature, num_classes=5)

    @staticmethod
    def _cat_hits_tracks(hit_rows, track_raw):
        """Append ``torch.tensor(track_raw)`` to ``hit_rows`` along dim 0, skipping it when empty."""
        if len(track_raw) > 0:
            return torch.cat((hit_rows, torch.tensor(track_raw)), dim=0)
        return hit_rows

    @classmethod
    def _pandora_features_from(cls, output):
        """Look up the Pandora PFO attributes of every hit/track row.

        Each row's PFO link indexes a row of ``X_pandora``; rows whose link is
        -1 get zeros instead of looked-up values.
        """
        X_pandora = torch.tensor(output["X_pandora"])
        pfo_link = cls._cat_hits_tracks(
            torch.tensor(output["pfo_calohit"]), output["pfo_track"]
        )
        unlinked = pfo_link == -1
        if X_pandora.numel() == 0 and bool(unlinked.all()):
            # No PFOs and no row links to one: a single dummy row keeps the
            # index-0 placeholder below valid; every result is zeroed anyway.
            X_pandora = torch.zeros((1, 9), dtype=X_pandora.dtype)
        # -1 is not a usable row index: point unlinked rows at row 0 for the
        # gather, then overwrite their values with zeros.
        rows = pfo_link.masked_fill(unlinked, 0)

        def gather(columns):
            values = X_pandora[rows, columns]
            values[unlinked] = 0
            return values

        return PandoraFeatures(
            pfo_energy=gather(7),
            pandora_mom=gather(8),
            pandora_ref_point=gather(slice(4, 7)),
            pandora_pid=gather(0),
            pandora_pfo_link=pfo_link,
            pandora_mom_components=gather(slice(1, 4)),
        )

    @classmethod
    def from_data(cls, output, number_hits, args, number_part):
        """Build a :class:`Hits` from the raw arrays of one event.

        Args:
            output: Mapping of per-event arrays accepted by ``torch.tensor``.
                Keys read: ``"X_hit"``, ``"X_track"``, ``"ygen_hit"``,
                ``"ygen_track"``, plus ``"X_pandora"``, ``"pfo_calohit"`` and
                ``"pfo_track"`` when ``args.pandora`` is truthy. An empty
                ``"X_track"`` means the event has no track rows.
            number_hits: Unused; kept for call-site compatibility.
            args: Configuration object; only ``args.pandora`` is read.
            number_part: Unused; kept for call-site compatibility.

        Returns:
            A ``Hits`` whose tensors hold the hit rows followed by the track rows.

        Raises:
            RuntimeError: raised by ``one_hot`` if a type code is outside 0..4.
        """
        hit_particle_link = cls._cat_hits_tracks(
            torch.tensor(output["ygen_hit"]), output["ygen_track"]
        )
        pandora_features = cls._pandora_features_from(output) if args.pandora else None

        X_hit = torch.tensor(output["X_hit"])
        hit_type = X_hit[:, 10] + 1  # calo hit type shifted to 1, 2, 3, 4
        hit_pos = X_hit[:, 6:9]
        hit_e = X_hit[:, 5]
        # Zero placeholder for quantities only tracks carry. zeros_like (rather
        # than `x * 0`) stays 0 where X_hit holds NaN/inf.
        hit_zero = torch.zeros_like(hit_e)

        if len(output["X_track"]) > 0:
            X_track = torch.tensor(output["X_track"])
            hit_type_feature = torch.cat((hit_type, X_track[:, 0]), dim=0)  # elemtype (1 for tracks)
            pos_xyz_hits = torch.cat((hit_pos, X_track[:, 12:15]), dim=0)  # referencePoint_calo
            e = torch.cat((hit_e, torch.zeros_like(X_track[:, 5])), dim=0)
            p = torch.cat((hit_zero, X_track[:, 5]), dim=0)
            pos_pxpypz = torch.cat((torch.zeros_like(hit_pos), X_track[:, 6:9]), dim=0)
            # Explicit 22:25: the hit block has 3 columns, so exactly 3 are needed.
            pos_pxpypz_calo = torch.cat((torch.zeros_like(hit_pos), X_track[:, 22:25]), dim=0)
            # NOTE(review): presumably chi2 / ndf; a zero in column 16 yields inf/nan here.
            chi_squared_tracks = torch.cat((hit_zero, X_track[:, 15] / X_track[:, 16]), dim=0)
        else:
            hit_type_feature = hit_type
            pos_xyz_hits = hit_pos
            e = hit_e
            p = hit_zero
            # Separate allocations so no two fields share storage.
            pos_pxpypz = torch.zeros_like(hit_pos)
            pos_pxpypz_calo = torch.zeros_like(hit_pos)
            chi_squared_tracks = torch.zeros_like(hit_e)

        hit_type_feature = hit_type_feature.to(torch.int64)
        return cls(
            pos_xyz_hits=pos_xyz_hits,
            pos_pxpypz=pos_pxpypz,
            pos_pxpypz_calo=pos_pxpypz_calo,
            p_hits=p.view(-1, 1),
            e_hits=e.view(-1, 1),
            hit_particle_link=hit_particle_link,
            pandora_features=pandora_features,
            hit_type_feature=hit_type_feature,
            chi_squared_tracks=chi_squared_tracks,
            hit_type_one_hot=torch.nn.functional.one_hot(hit_type_feature, num_classes=5),
        )