# HitPF_demo/src/utils/post_clustering_features.py
# Synced from GitHub by github-actions[bot] (commits f6dbbfb / cc0720f).
import torch
from torch_scatter import scatter_sum, scatter_std
def calculate_phi(x, y, z=None):
    """Return the azimuthal angle phi = atan2(y, x).

    ``z`` is accepted (and ignored) only so the signature mirrors
    ``calculate_eta`` and both can be called uniformly.
    """
    phi = torch.arctan2(y, x)
    return phi
def calculate_eta(x, y, z):
    """Return the pseudorapidity eta = -ln(tan(theta / 2)).

    The polar angle theta is measured from the z axis:
    theta = atan2(sqrt(x^2 + y^2), z).
    """
    transverse = torch.sqrt(x * x + y * y)
    polar_angle = torch.arctan2(transverse, z)
    return -torch.log(torch.tan(0.5 * polar_angle))
def get_post_clustering_features(graphs_new, sum_e):
    '''
    Obtain graph-level qualitative features that can then be used to regress the energy corr. factor.

    :param graphs_new: Output from the previous step - batched graph of clustered, matched
        showers. Node features live in ``ndata["h"]`` (column 5: ECAL flag, 6: HCAL flag,
        7: muon flag, 8: hit energy, 9: track momentum -- assumed layout, confirm with the
        upstream producer) plus per-node ``ndata["chi_squared_tracks"]``.
    :param sum_e: Per-graph total energy, tensor of length = number of graphs in the batch.
    :return: ``(num_graphs, 11)`` float tensor of features; NaNs/Infs replaced with 0.
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    num_graphs = len(batch_num_nodes)
    # Map each node to its graph index in one vectorized call instead of a Python loop.
    batch_idx = torch.repeat_interleave(
        torch.arange(num_graphs, device=graphs_new.device), batch_num_nodes
    )
    e_hits = graphs_new.ndata["h"][:, 8]
    muon_hits = graphs_new.ndata["h"][:, 7]
    filter_muon = torch.where(muon_hits)[0]
    # dim_size=num_graphs (not batch_idx.max()+1): trailing graphs with zero selected
    # hits still get a zero row, and dim_size receives a plain int as expected.
    per_graph_e_hits_muon = scatter_sum(e_hits[filter_muon], batch_idx[filter_muon], dim_size=num_graphs)
    per_graph_n_hits_muon = scatter_sum((e_hits[filter_muon] > 0).type(torch.int), batch_idx[filter_muon], dim_size=num_graphs)
    ecal_hits = graphs_new.ndata["h"][:, 5]
    filter_ecal = torch.where(ecal_hits)[0]
    hcal_hits = graphs_new.ndata["h"][:, 6]
    filter_hcal = torch.where(hcal_hits)[0]
    per_graph_e_hits_ecal = scatter_sum(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=num_graphs)
    # Per-graph variance (std squared) of the ECAL hit energies.
    per_graph_e_hits_ecal_dispersion = scatter_std(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=num_graphs) ** 2
    per_graph_e_hits_hcal = scatter_sum(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=num_graphs)
    # Per-graph variance of the HCAL hit energies.
    # !!!!! TODO: Retrain the base EC models using this definition !!!!!
    per_graph_e_hits_hcal_dispersion = scatter_std(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=num_graphs) ** 2
    track_p = scatter_sum(graphs_new.ndata["h"][:, 9], batch_idx, dim_size=num_graphs)
    chis_tracks = scatter_sum(graphs_new.ndata["chi_squared_tracks"], batch_idx, dim_size=num_graphs)
    num_tracks = scatter_sum((graphs_new.ndata["h"][:, 9] > 0).type(torch.int), batch_idx, dim_size=num_graphs)
    track_p = track_p / num_tracks  # mean track momentum per graph
    track_p[num_tracks == 0] = 0.   # trackless graphs: zero instead of NaN/Inf
    # Mean track chi^2; NaN for trackless graphs, zeroed by nan_to_num below.
    chis_tracks = chis_tracks / num_tracks
    num_hits = batch_num_nodes
    # Integer features are cast to float so torch.stack sees a uniform dtype.
    return torch.nan_to_num(
        torch.stack([per_graph_e_hits_ecal / sum_e,
                     per_graph_e_hits_hcal / sum_e,
                     num_hits.float(), track_p,
                     per_graph_e_hits_ecal_dispersion,
                     per_graph_e_hits_hcal_dispersion,
                     sum_e, num_tracks.float(), torch.clamp(chis_tracks, -5, 5),
                     per_graph_e_hits_muon,
                     per_graph_n_hits_muon.float()
                     ]).T
    )
def get_extra_features(graphs_new, betas):
    '''
    Obtain extra graph-level features for debugging of the fakes.

    :param graphs_new: Batched graph; only ``batch_num_nodes()`` and ``device`` are used.
    :param betas: Per-node beta scores, ordered graph by graph (one entry per node).
    :return: ``(num_graphs, 1 + n_highest_betas)`` tensor in ``betas``' dtype:
        the node count of each graph followed by its top beta values,
        zero-padded when a graph has fewer nodes than requested.
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    # Vectorized node -> graph-index mapping (replaces the Python-loop build).
    batch_idx = torch.repeat_interleave(
        torch.arange(len(batch_num_nodes), device=graphs_new.device), batch_num_nodes
    )
    n_highest_betas = 1
    topk_highest_betas = []
    for i in range(len(batch_num_nodes)):
        betas_i = betas[batch_idx == i]
        # torch.topk raises when k > numel, so clamp k BEFORE the call and pad after.
        # (The original checked the length only after topk, which could never fire.)
        k = min(n_highest_betas, betas_i.numel())
        top_vals = torch.topk(betas_i, k).values
        if k < n_highest_betas:
            pad = torch.zeros(n_highest_betas - k, dtype=betas.dtype, device=betas.device)
            top_vals = torch.cat([top_vals, pad])
        topk_highest_betas.append(top_vals)
    topk_highest_betas = torch.stack(topk_highest_betas)
    # Cast the integer node counts to betas' dtype so torch.cat gets a uniform dtype.
    counts = batch_num_nodes.view(-1, 1).to(topk_highest_betas.dtype)
    features = torch.cat([counts, topk_highest_betas], dim=1)
    return features