Spaces:
Sleeping
Sleeping
| import torch | |
| from torch_scatter import scatter_sum, scatter_std | |
def calculate_phi(x, y, z=None):
    """Return the azimuthal angle phi = atan2(y, x).

    The *z* argument is accepted but unused (kept for call-site symmetry
    with calculate_eta).
    """
    phi = torch.arctan2(y, x)
    return phi
def calculate_eta(x, y, z):
    """Return the pseudorapidity eta = -ln(tan(theta / 2)), where theta is
    the polar angle of the point (x, y, z)."""
    transverse = torch.sqrt(x ** 2 + y ** 2)
    polar_angle = torch.arctan2(transverse, z)
    return -torch.log(torch.tan(polar_angle / 2))
def get_post_clustering_features(graphs_new, sum_e):
    '''
    Obtain graph-level qualitative features that can then be used to regress the energy corr. factor.
    :param graphs_new: Output from the previous step - clustered, matched showers
        (batched graph; node features live in ndata["h"] — columns 5/6/7 are
        treated as ecal/hcal/muon indicators, 8 as hit energy, 9 as track
        momentum — assumed from the indices below, TODO confirm against the
        feature builder).
    :param sum_e: per-graph total hit energy used to normalise the calorimeter sums
    :return: (num_graphs, 11) feature tensor with NaN/inf replaced by finite values
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    num_graphs = len(batch_num_nodes)
    # Node -> graph index map, built vectorised instead of a Python loop.
    counts = torch.as_tensor(batch_num_nodes, device=graphs_new.device)
    batch_idx = torch.repeat_interleave(
        torch.arange(num_graphs, device=graphs_new.device), counts
    )
    e_hits = graphs_new.ndata["h"][:, 8]
    muon_hits = graphs_new.ndata["h"][:, 7]
    filter_muon = torch.where(muon_hits)[0]
    # Use num_graphs (an int) as dim_size rather than batch_idx.max() + 1 (a
    # tensor), which also under-sizes the output if trailing graphs are empty.
    per_graph_e_hits_muon = scatter_sum(
        e_hits[filter_muon], batch_idx[filter_muon], dim_size=num_graphs
    )
    per_graph_n_hits_muon = scatter_sum(
        (e_hits[filter_muon] > 0).type(torch.int),
        batch_idx[filter_muon],
        dim_size=num_graphs,
    )
    ecal_hits = graphs_new.ndata["h"][:, 5]
    filter_ecal = torch.where(ecal_hits)[0]
    hcal_hits = graphs_new.ndata["h"][:, 6]
    filter_hcal = torch.where(hcal_hits)[0]
    per_graph_e_hits_ecal = scatter_sum(
        e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=num_graphs
    )
    # Energy dispersion = variance (std squared) of the hit energies.
    per_graph_e_hits_ecal_dispersion = scatter_std(
        e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=num_graphs
    ) ** 2
    per_graph_e_hits_hcal = scatter_sum(
        e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=num_graphs
    )
    # !!!!! TODO: Retrain the base EC models using this definition !!!!!
    per_graph_e_hits_hcal_dispersion = scatter_std(
        e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=num_graphs
    ) ** 2
    track_p = scatter_sum(graphs_new.ndata["h"][:, 9], batch_idx, dim_size=num_graphs)
    chis_tracks = scatter_sum(
        graphs_new.ndata["chi_squared_tracks"], batch_idx, dim_size=num_graphs
    )
    num_tracks = scatter_sum(
        (graphs_new.ndata["h"][:, 9] > 0).type(torch.int),
        batch_idx,
        dim_size=num_graphs,
    )
    # Mean track momentum / chi^2 per graph. Explicitly zero the trackless
    # graphs for BOTH quantities: previously only track_p was zeroed, so a
    # positive chi^2 sum divided by zero tracks became inf and was clamped to
    # 5 instead of being reported as 0.
    track_p = track_p / num_tracks
    track_p[num_tracks == 0] = 0.0
    chis_tracks = chis_tracks / num_tracks
    chis_tracks[num_tracks == 0] = 0.0
    num_hits = graphs_new.batch_num_nodes()
    return torch.nan_to_num(
        torch.stack(
            [
                per_graph_e_hits_ecal / sum_e,
                per_graph_e_hits_hcal / sum_e,
                num_hits,
                track_p,
                per_graph_e_hits_ecal_dispersion,
                per_graph_e_hits_hcal_dispersion,
                sum_e,
                num_tracks,
                torch.clamp(chis_tracks, -5, 5),
                per_graph_e_hits_muon,
                per_graph_n_hits_muon,
            ]
        ).T
    )
def get_extra_features(graphs_new, betas):
    '''
    Obtain extra graph-level features for debugging of the fakes.
    :param graphs_new: batched graph; only batch_num_nodes() and .device are used
    :param betas: per-node beta scores, one entry per node across the batch
    :return: (num_graphs, 1 + n_highest_betas) tensor whose columns are
        [num_nodes_in_graph, top-k betas], zero-padded for graphs that have
        fewer than n_highest_betas nodes
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    num_graphs = len(batch_num_nodes)
    # Node -> graph index map.
    counts = torch.as_tensor(batch_num_nodes, device=graphs_new.device)
    batch_idx = torch.repeat_interleave(
        torch.arange(num_graphs, device=graphs_new.device), counts
    )
    n_highest_betas = 1
    topk_highest_betas = []
    for i in range(num_graphs):
        betas_i = betas[batch_idx == i]
        # torch.topk raises when k > numel(), so cap k up front; the old
        # post-hoc length check was unreachable (and then accessed .values on
        # an already-unwrapped tensor, which would have raised AttributeError).
        k = min(n_highest_betas, betas_i.numel())
        top_values = torch.topk(betas_i, k).values if k > 0 else betas.new_zeros(0)
        if len(top_values) < n_highest_betas:
            # Pad with zeros on the same device/dtype as betas.
            padding = betas.new_zeros(n_highest_betas - len(top_values))
            top_values = torch.cat([top_values, padding])
        topk_highest_betas.append(top_values)
    topk_highest_betas = torch.stack(topk_highest_betas)
    # Concat with batch_num_nodes (dtype promotion handled by torch.cat).
    features = torch.cat([batch_num_nodes.view(-1, 1), topk_highest_betas], dim=1)
    return features