Spaces:
Sleeping
Sleeping
| from typing import Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch_scatter import scatter_max, scatter_add, scatter_mean | |
| from src.layers.object_cond import assert_no_nans, scatter_count, batch_cluster_indices | |
| import dgl | |
def calc_energy_loss(
    batch, cluster_space_coords, beta, beta_stabilizing="soft_q_scaling", qmin=0.1, radius=0.7,
    e_frac_loss_return_particles=False, y=None, select_centers_by_particle=True
):
    """Compute per-graph clustered-energy fractions around selected centre hits.

    For every graph in ``batch`` a set of centre hits is selected, every hit
    closer than a radius to a centre (in ``cluster_space_coords``) is assigned
    to that centre by ``get_clustering``, and for each centre two ratios are
    formed from ``g.ndata["e_hits"]``:

    * energy of all hits in the cluster / energy of all hits of the centre's
      particle;
    * energy of the cluster hits that really belong to the centre's particle /
      energy of all hits of that particle.

    Args:
        batch: batched DGL graph; each graph must carry the node features
            ``"particle_number"`` and ``"e_hits"``.
        cluster_space_coords: per-node coordinates, concatenated over the
            graphs of ``batch`` in ``dgl.unbatch`` order.
        beta: per-node beta values, concatenated in the same order.
        beta_stabilizing: one of ``"paper"``, ``"clip"``, ``"soft_q_scaling"``.
            Only ``"clip"`` affects the result (beta is clipped to
            ``[0, 1 - 1e-4]`` before centres are selected); any other value
            raises ``ValueError``.
        qmin: unused by this function; kept for interface compatibility.
        radius: clustering radius, or the string ``"dynamic"`` to use half of
            the smallest nearest-neighbour distance between centres (never
            below 0.1). ``"dynamic"`` needs at least two centres per graph.
        e_frac_loss_return_particles: if True, return per-centre values and
            bookkeeping instead of batch means.
        y: optional per-particle table; column 6 is read as an integer PID.
            When ``None`` the PID counters are left empty.
        select_centers_by_particle: if True, pretend we know which hits belong
            to which particle and use the highest-beta hit of each particle as
            its centre; otherwise pick centres from the beta ordering alone.

    Returns:
        If ``e_frac_loss_return_particles`` is False: a tuple
        ``(loss_E_frac, loss_E_frac_true)`` of scalar tensors (mean over
        graphs of the per-graph mean ratios).
        Otherwise: ``(loss_E_frac, [loss_E_frac_true, particle_ids_all,
        reco_count, non_reco_count, total_count])`` where the first three are
        per-graph lists and the counts are ``dict`` objects keyed by PID.

    Raises:
        ValueError: if ``beta_stabilizing`` is not a known mode.
    """
    if beta_stabilizing == "clip":
        beta = beta.clip(0.0, 1 - 1e-4)
    elif beta_stabilizing not in ("paper", "soft_q_scaling"):
        raise ValueError(f"beta_stablizing mode {beta_stabilizing} is not known")

    # y is optional (defaults to None); PID statistics are only filled when it
    # is supplied instead of failing on ``None[:, 6]``.
    part_pids = y[:, 6].long() if y is not None else None

    loss_E_frac = []
    loss_E_frac_true = []
    particle_ids_all = []
    reco_count = {}  # per-PID count of particles that own a selected centre
    non_reco_count = {}  # per-PID count of particles without a selected centre
    total_count = {}  # per-PID count of all particles seen

    node_counter = 0
    for g in dgl.unbatch(batch):
        particle_id = g.ndata["particle_number"]
        number_of_objects = len(particle_id.unique())
        print("No. of objects", number_of_objects)
        non = g.number_of_nodes()
        node_slice = slice(node_counter, node_counter + non)
        betas = beta[node_slice]

        if select_centers_by_particle:
            # Index of the highest-beta hit of every particle.
            # assumes particle numbers are 1-based and contiguous, otherwise
            # scatter_max yields entries for empty groups -- TODO confirm
            _, selected_centers = scatter_max(
                betas.flatten().cpu(), particle_id.cpu().long() - 1
            )
            selected_centers = selected_centers.to(g.device)
        else:
            # NOTE(review): ascending sort selects the *lowest*-beta hits as
            # centres; this looks unintended -- confirm before relying on it.
            _, indices = torch.sort(betas.view(-1), descending=False)
            selected_centers = indices[0:number_of_objects]

        if part_pids is not None:
            all_particles = set((particle_id.unique() - 1).long().tolist())
            reco_particles = set((particle_id[selected_centers] - 1).long().tolist())
            for particle in all_particles:
                # presumably rows of y are indexed by (particle number - 1);
                # verify against the caller for multi-graph batches
                curr_pid = part_pids[particle].item()
                total_count[curr_pid] = total_count.get(curr_pid, 0) + 1
                bucket = reco_count if particle in reco_particles else non_reco_count
                bucket[curr_pid] = bucket.get(curr_pid, 0) + 1

        X = cluster_space_coords[node_slice]
        if radius == "dynamic":
            centres = X[selected_centers]
            pairwise = torch.cdist(centres, centres, p=2)
            # Column 0 of the row-wise sorted distances is the centre itself
            # (distance 0); column 1 is the distance to its nearest neighbour.
            nearest = torch.sort(pairwise, dim=1).values[:, 1]
            current_radius = max(0.1, nearest.min() / 2)
            print("Current radius", current_radius)
        else:
            print("Radius", radius)
            current_radius = radius

        clusterings = get_clustering(selected_centers, X, betas, td=current_radius)
        clusterings = clusterings.to(g.device)
        node_counter += non

        e_hits = g.ndata["e_hits"]
        frac_energy = []
        frac_energy_true = []
        particle_ids = []
        for cluster_idx, alpha in enumerate(selected_centers):
            id_particle = particle_id[alpha]
            true_mask_particle = particle_id == id_particle
            true_energy = torch.sum(e_hits[true_mask_particle])
            mask_clustering_particle = clusterings == cluster_idx
            clustered_energy = torch.sum(e_hits[mask_clustering_particle])
            # Only consider how much has been correctly assigned.
            clustered_energy_true = torch.sum(
                e_hits[mask_clustering_particle & true_mask_particle.flatten()]
            )
            # The epsilon keeps the ratio finite for zero-energy particles.
            frac_energy.append(clustered_energy / (true_energy + 1e-7))
            frac_energy_true.append(clustered_energy_true / (true_energy + 1e-7))
            particle_ids.append(id_particle.cpu().long().item())

        frac_energy = torch.stack(frac_energy, dim=0)
        frac_energy_true = torch.stack(frac_energy_true, dim=0)
        if not e_frac_loss_return_particles:
            frac_energy = torch.mean(frac_energy)
            frac_energy_true = torch.mean(frac_energy_true)
        loss_E_frac.append(frac_energy)
        loss_E_frac_true.append(frac_energy_true)
        particle_ids_all.append(particle_ids)

    if e_frac_loss_return_particles:
        return loss_E_frac, [loss_E_frac_true, particle_ids_all, reco_count, non_reco_count, total_count]
    loss_E_frac = torch.mean(torch.stack(loss_E_frac, dim=0))
    loss_E_frac_true = torch.mean(torch.stack(loss_E_frac_true, dim=0))
    return loss_E_frac, loss_E_frac_true
def get_clustering(index_alpha_i, X, betas, td=0.7):
    """Greedily assign points to clusters around the given centre points.

    Centres are processed in the order given by ``index_alpha_i``. Each centre
    claims every still-unassigned point whose Euclidean distance to it in
    ``X`` is below ``td``. Afterwards every centre is forced into its own
    cluster, even if an earlier centre had already claimed it.

    Args:
        index_alpha_i: iterable of point indices (e.g. a 1-D long tensor) to
            use as cluster centres; position ``k`` defines cluster id ``k``.
        X: tensor of point coordinates, indexed along dim 0 by point.
        betas: per-point tensor; only its length along dim 0 and its device
            are used here.
        td: distance threshold below which a point joins a centre's cluster.

    Returns:
        A 1-D ``torch.long`` tensor of length ``betas.size(0)``, allocated on
        ``betas.device``, holding the cluster id of every point; points that
        were not claimed by any centre are marked ``-1``.
    """
    n_points = betas.size(0)
    device = betas.device
    # Allocate both index and result tensors on the same device so that the
    # indexed assignments below never mix devices.
    unassigned = torch.arange(n_points, device=device)
    clustering = torch.full((n_points,), -1, dtype=torch.long, device=device)

    for cluster_idx, index_condpoint in enumerate(index_alpha_i):
        within = torch.norm(X[unassigned] - X[index_condpoint], dim=-1) < td
        clustering[unassigned[within]] = cluster_idx
        unassigned = unassigned[~within]

    # A centre that lies within ``td`` of an earlier centre was absorbed
    # above; put every centre back into its own cluster.
    for cluster_idx, index_condpoint in enumerate(index_alpha_i):
        clustering[index_condpoint] = cluster_idx
    return clustering