Spaces:
Sleeping
Sleeping
| from typing import Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch_scatter import scatter_max, scatter_add, scatter_mean | |
| from src.layers.loss_fill_space_torch import LLFillSpace | |
def safe_index(arr, index):
    """Return the 1-based position of ``index`` in ``arr``, or 0 if it is absent.

    Position 0 is reserved for "not in the array", which makes the result
    usable directly as a one-hot slot.
    """
    return arr.index(index) + 1 if index in arr else 0
def assert_no_nans(x):
    """
    Raises AssertionError if there is a nan in the tensor.

    The offending tensor is printed before raising to help debugging.
    The error is raised explicitly (not via an ``assert`` statement) so the
    check is still active when Python runs with ``-O``.
    """
    if torch.isnan(x).any():
        print(x)
        raise AssertionError("tensor contains NaN values")
# FIXME: Use a logger instead of this
# Module-level switch: set to True to make debug() print.
DEBUG = False


def debug(*args, **kwargs):
    """Forward all arguments to ``print`` when the module flag DEBUG is set."""
    if not DEBUG:
        return
    print(*args, **kwargs)
def calc_LV_Lbeta(
    original_coords,
    g,
    distance_threshold,
    beta: torch.Tensor,
    cluster_space_coords: torch.Tensor,  # Predicted by model
    cluster_index_per_event: torch.Tensor,  # Truth hit->cluster index, e.g. [0, 1, 1, 0, 1, -1, 0, 1, 1]
    batch: torch.Tensor,  # E.g. [0, 0, 0, 0, 1, 1, 1, 1, 1]
    # From here on just parameters
    qmin: float = 0.1,
    s_B: float = 1.0,
    noise_cluster_index: int = 0,  # cluster_index entries with this value are noise/noise
    beta_stabilizing="soft_q_scaling",
    huberize_norm_for_V_attractive=False,
    beta_term_option="paper",
    frac_combinations=0,  # fraction of the all possible pairs to be used for the clustering loss
    attr_weight=1.0,
    repul_weight=1.0,
    use_average_cc_pos=0.0,
    loss_type="hgcalimplementation",
    tracking=False,
    dis=False,
    beta_type="default",
    noise_logits=None,
    lorentz_norm=False,
    spatial_part_only=False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]:
    """
    Calculates the L_V and L_beta object condensation losses.

    Concepts:
    - A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)),
      and to exactly one event (batch is (n_hits,))
    - A cluster index of `noise_cluster_index` means the cluster is a noise cluster.
      There is typically one noise cluster per event. Any hit in a noise cluster
      is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a
      better term.
    - An 'object' is a cluster that is *not* a noise cluster.

    Args:
        original_coords: Per-hit input coordinates. Only used for `loss_coord`, where
            the row-normalised directions are compared with those of
            `cluster_space_coords` (so the two must be broadcast-compatible).
        g: Graph-like object exposing `g.ndata[...]`. Only read when
            loss_type == "vrepweighted": key "weights" if `tracking`, otherwise key
            "hit_link_modified".
        distance_threshold: Per-hit tensor, only read when `dis` is True.
        beta: (n_hits,) condensation strength per hit. NaNs are reported and
            replaced by 0.
        cluster_space_coords: (n_hits, cluster_space_dim) predicted coordinates.
        cluster_index_per_event: (n_hits,) truth cluster index within the event.
        batch: (n_hits,) event index of every hit.
        qmin: Constant added to the charge q.
        s_B: Scale of the noise part of the beta loss.
        noise_cluster_index: Cluster index that marks noise; only 0 is supported.
        beta_stabilizing: Only consulted when beta_type == "default" and loss_type is
            neither "hgcalimplementation" nor "vrepweighted". Choices:
            paper: q = beta.arctanh()**2 + qmin
            clip: beta is clipped to [0, 1-1e-4], q = beta.arctanh()**2 + qmin
            soft_q_scaling: q = (clip(beta)/1.002).arctanh()**2 + qmin
        huberize_norm_for_V_attractive: Huberizes the norms used in the attractive
            potential (only for loss types other than the two hgcal-style ones).
        beta_term_option: Choices are ['paper', 'short-range-potential'] (only for loss
            types other than the two hgcal-style ones). 'short-range-potential'
            introduces a short range potential around high beta points, acting like
            V_attractive.
        frac_combinations: If non-zero, `L_clusters_calc` is invoked with this fraction.
        attr_weight, repul_weight: Weights of the attractive / repulsive parts in L_V.
        use_average_cc_pos: If > 0, blends the condensation point position with the
            q-weighted mean position of the object by this factor.
        loss_type: "hgcalimplementation", "vrepweighted", or anything else for the
            per-event normalised formulation.
        tracking: Selects the per-hit-weighted attractive term for "vrepweighted".
        dis: If True, repulsive norms are scaled by a beta-weighted mean of
            `distance_threshold` per object.
        beta_type: "default" (q derived from beta), or "pt" / "pt+bc" (q = beta and all
            beta losses are reported as zero). "pt+bc" additionally returns a weighted
            binary cross-entropy on `noise_logits`.
        noise_logits: Required for beta_type == "pt+bc"; output of the noise classifier
            (whether a particle belongs to a jet or not).
        lorentz_norm: Use |d_0^2 - sum_i d_i^2| instead of the squared Euclidean
            distance (hgcal-style loss types only).
        spatial_part_only: Use only columns 1:4 for the squared distance
            (hgcal-style loss types only, ignored when `lorentz_norm` is set).

    Returns:
        dict with keys "loss_potential", "loss_beta", "loss_beta_sig",
        "loss_beta_noise", "loss_attractive", "loss_repulsive", "loss_coord" and,
        for beta_type == "pt+bc", "loss_noise_classification".

    Raises:
        NotImplementedError: if noise_cluster_index != 0.
        ValueError: for an unknown beta_stabilizing, beta_type or beta_term_option.
        AssertionError: if q (or beta after NaN replacement) contains NaNs, or an
            internal shape check fails.

    Note this function has modifications w.r.t. the implementation in 2002.03605:
    - The norms for V_repulsive are now Gaussian (instead of linear hinge)
    """
    device = beta.device
    hgcal_like = loss_type in ("hgcalimplementation", "vrepweighted")

    if torch.isnan(beta).any():
        print("There are nans in beta! L198", len(beta[torch.isnan(beta)]))
        beta = torch.nan_to_num(beta, nan=0.0)
    assert_no_nans(beta)

    # ________________________________
    # Calculate a bunch of needed counts and indices locally
    # cluster_index: unique index over events
    # E.g. cluster_index_per_event=[ 0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
    #      -> cluster_index=[ 0, 0, 1, 2, 3, 3, 4 ]
    cluster_index, n_clusters_per_event = batch_cluster_indices(
        cluster_index_per_event, batch
    )
    n_clusters = n_clusters_per_event.sum()
    n_hits, cluster_space_dim = cluster_space_coords.size()
    batch_size = batch.max().detach().cpu().item() + 1
    n_hits_per_event = scatter_count(batch)
    # Index of cluster -> event (n_clusters,)
    batch_cluster = scatter_counts_to_indices(n_clusters_per_event)
    # Per-hit boolean, indicating whether hit is sig or noise
    is_noise = cluster_index_per_event == noise_cluster_index
    is_sig = ~is_noise
    n_hits_sig = is_sig.sum()
    # Per-cluster boolean, indicating whether cluster is an object or noise
    is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
    # FIXME: This assumes noise_cluster_index == 0!!
    # Not sure how to do this in a performant way in case noise_cluster_index != 0
    if noise_cluster_index != 0:
        raise NotImplementedError
    object_index_per_event = cluster_index_per_event[is_sig] - 1
    object_index, n_objects_per_event = batch_cluster_indices(
        object_index_per_event, batch[is_sig]
    )
    n_hits_per_object = scatter_count(object_index)
    batch_object = batch_cluster[is_object]
    n_objects = is_object.sum()
    assert object_index.size() == (n_hits_sig,)
    assert is_object.size() == (n_clusters,)
    assert torch.all(n_hits_per_object > 0)
    assert object_index.max() + 1 == n_objects

    # ________________________________
    # L_V term
    # Calculate q
    if beta_type == "default":
        if hgcal_like:
            q = (beta.clip(0.0, 1 - 1e-4).arctanh() / 1.01) ** 2 + qmin
        elif beta_stabilizing == "paper":
            q = beta.arctanh() ** 2 + qmin
        elif beta_stabilizing == "clip":
            beta = beta.clip(0.0, 1 - 1e-4)
            q = beta.arctanh() ** 2 + qmin
        elif beta_stabilizing == "soft_q_scaling":
            q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
        else:
            raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
    elif beta_type in ("pt", "pt+bc"):
        q = beta
    else:
        # Previously fell through and failed later with an unbound `q`.
        valid_beta_types = ["default", "pt", "pt+bc"]
        raise ValueError(
            f'beta_type "{beta_type}" is not valid, choose from {valid_beta_types}'
        )
    assert_no_nans(q)
    assert q.device == device
    assert q.size() == (n_hits,)

    # Calculate q_alpha, the max q per object, and the indices of said maxima
    q_alpha, index_alpha = scatter_max(q[is_sig], object_index)
    assert q_alpha.size() == (n_objects,)
    # Get the cluster space coordinates and betas for these maxima hits too
    x_alpha = cluster_space_coords[is_sig][index_alpha]
    if use_average_cc_pos > 0:
        # q-weighted mean position per object. Broadcasting (instead of a hard-coded
        # repeat over 3 columns) keeps this valid for any cluster_space_dim.
        x_alpha_sum = scatter_add(
            q[is_sig].view(-1, 1) * cluster_space_coords[is_sig],
            object_index,
            dim=0,
        )
        qbeta_alpha_sum = scatter_add(q[is_sig], object_index) + 1e-9
        div_fac = 1 / qbeta_alpha_sum
        div_fac = torch.nan_to_num(div_fac, nan=0)
        x_alpha_mean = x_alpha_sum * div_fac.view(-1, 1)
        x_alpha = use_average_cc_pos * x_alpha_mean + (1 - use_average_cc_pos) * x_alpha
    if dis:
        # beta-weighted mean of distance_threshold per object
        phi_sum = scatter_add(
            beta[is_sig].view(-1) * distance_threshold[is_sig].view(-1),
            object_index,
            dim=0,
        )
        phi_alpha_sum = scatter_add(beta[is_sig].view(-1), object_index) + 1e-9
        phi_alpha = phi_sum / phi_alpha_sum
    beta_alpha = beta[is_sig][index_alpha]
    assert x_alpha.size() == (n_objects, cluster_space_dim)
    assert beta_alpha.size() == (n_objects,)

    # Connectivity matrix from hit (row) -> cluster (column)
    # Index to matrix, e.g.:
    # [1, 3, 1, 0] --> [
    #     [0, 1, 0, 0],
    #     [0, 0, 0, 1],
    #     [0, 1, 0, 0],
    #     [1, 0, 0, 0]
    #     ]
    M = torch.nn.functional.one_hot(cluster_index).long()
    # Anti-connectivity matrix; be sure not to connect hits to clusters in different events!
    M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M
    # Throw away noise cluster columns; we never need them
    M = M[:, is_object]
    M_inv = M_inv[:, is_object]
    assert M.size() == (n_hits, n_objects)
    assert M_inv.size() == (n_hits, n_objects)

    # Calculate all norms
    # Warning: Should not be used without a mask!
    # Contains norms between hits and objects from different events
    # (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim)
    #   gives (n_hits, n_objects, cluster_space_dim)
    norms = (cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)).norm(dim=-1)
    assert norms.size() == (n_hits, n_objects)

    # NOTE(review): L_clusters is computed but not part of the returned dict.
    L_clusters = torch.tensor(0.0).to(device)
    if frac_combinations != 0:
        L_clusters = L_clusters_calc(
            batch, cluster_space_coords, cluster_index, frac_combinations, q
        )

    # -------
    # Attractive potential term
    # First get all the relevant norms: We only want norms of signal hits
    # w.r.t. the object they belong to, i.e. no noise hits and no noise clusters.
    # First select all norms of all signal hits w.r.t. all objects, mask out later
    if hgcal_like:
        N_k = torch.sum(M, dim=0)  # number of hits per object
        # From here on `norms` holds *squared* distances for these loss types.
        if lorentz_norm:
            diff = cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)
            norms = diff[:, :, 0] ** 2 - torch.sum(diff[:, :, 1:] ** 2, dim=-1)
            norms = norms.abs()  ## ??? Why is this needed? wrong convention?
        elif spatial_part_only:
            norms = torch.sum(
                torch.square(
                    cluster_space_coords[:, 1:4].unsqueeze(1)
                    - x_alpha[:, 1:4].unsqueeze(0)
                ),
                dim=-1,
            )
        else:
            norms = torch.sum(
                torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)),
                dim=-1,
            )  # Take the norm squared
        norms_att = norms[is_sig]
        # Attractive function: log(e * d^2 / 2 + 1)
        norms_att = torch.log(
            torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att / 2 + 1
        )
    elif huberize_norm_for_V_attractive:
        norms_att = norms[is_sig]
        # Huberized version (linear but times 4)
        # Be sure to not move 'off-diagonal' away from zero
        # (i.e. norms of hits w.r.t. clusters they do _not_ belong to)
        norms_att = huber(norms_att + 1e-5, 4.0)
    else:
        norms_att = norms[is_sig]
        # Paper version is simply norms squared (no need for mask)
        norms_att = norms_att**2
    assert norms_att.size() == (n_hits_sig, n_objects)
    # Now apply the mask to keep only norms of signal hits w.r.t. to the object
    # they belong to
    norms_att *= M[is_sig]

    if loss_type == "hgcalimplementation":
        # (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects)
        V_attractive = q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
        assert V_attractive.size() == (n_hits_sig, n_objects)
        # Each object is accounted for separately ...
        V_attractive = V_attractive.sum(dim=0)  # K objects
        # ... and divided by its number of hits
        V_attractive = V_attractive.view(-1) / (N_k.view(-1) + 1e-3)
        L_V_attractive = torch.mean(V_attractive)
    elif loss_type == "vrepweighted":
        if tracking:
            # weight the vtx hits inside the shower
            V_attractive = (
                g.ndata["weights"][is_sig].unsqueeze(-1)
                * q[is_sig].unsqueeze(-1)
                * q_alpha.unsqueeze(0)
                * norms_att
            )
            assert V_attractive.size() == (n_hits_sig, n_objects)
            V_attractive = V_attractive.sum(dim=0)  # K objects
            L_V_attractive = torch.mean(V_attractive.view(-1))
            # NOTE(review): this branch never defined the per-object weights that the
            # repulsive and beta terms below read, which raised a NameError. Uniform
            # weights reduce those terms to a plain mean over objects, consistent with
            # the mean used for the attractive term here -- confirm this is intended.
            modified_showers = torch.ones_like(q_alpha)
        else:
            V_attractive = (
                q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
            )
            # weight modified showers with a higher weight
            # NOTE(review): "hit_link_modified" is indexed per hit of g while
            # object_index only covers signal hits -- presumably this needs [is_sig]
            # when noise hits are present; verify against the graph builder.
            modified_showers = scatter_max(g.ndata["hit_link_modified"], object_index)[
                0
            ]
            n_modified = torch.sum(modified_showers)
            weight_modified = len(modified_showers) / (2 * n_modified)
            weight_unmodified = len(modified_showers) / (
                2 * (len(modified_showers) - n_modified)
            )
            modified_showers[modified_showers > 0] = weight_modified
            modified_showers[modified_showers == 0] = weight_unmodified
            assert V_attractive.size() == (n_hits_sig, n_objects)
            V_attractive = V_attractive.sum(dim=0)  # K objects
            L_V_attractive = torch.sum(
                modified_showers.view(-1) * V_attractive.view(-1)
            ) / len(modified_showers)
    else:
        # Final potential term
        # (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects)
        V_attractive = q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
        assert V_attractive.size() == (n_hits_sig, n_objects)
        # Sum over hits, then sum per event, then divide by n_hits_per_event,
        # then sum over events
        V_attractive = (
            scatter_add(V_attractive.sum(dim=0), batch_object) / n_hits_per_event
        )
        assert V_attractive.size() == (batch_size,)
        L_V_attractive = V_attractive.sum()

    # -------
    # Repulsive potential term
    # Get all the relevant norms: We want norms of any hit w.r.t. to
    # objects they do *not* belong to, i.e. no noise clusters.
    # We do however want to keep norms of noise hits w.r.t. objects
    # Power-scale the norms: Gaussian scaling term instead of a cone
    # Mask out the norms of hits w.r.t. the cluster they belong to
    if hgcal_like:
        if dis:
            norms = norms / (2 * phi_alpha.unsqueeze(0) ** 2 + 1e-6)
            norms_rep = torch.exp(-(norms)) * M_inv
        else:
            norms_rep = torch.exp(-(norms) / 2) * M_inv
    else:
        norms_rep = torch.exp(-4.0 * norms**2) * M_inv
    # (n_hits, 1) * (1, n_objects) * (n_hits, n_objects)
    V_repulsive = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep
    # No need to apply a V = max(0, V); by construction V>=0
    assert V_repulsive.size() == (n_hits, n_objects)
    if hgcal_like:
        # Sum each object's repulsive terms and normalise by how many there are
        L_V_repulsive = V_repulsive.sum(dim=0)  # size number of objects
        # An object whose event has no other hit has zero repulsive terms; the sum is
        # then 0 as well, so clamping the count avoids a 0/0 = NaN loss.
        number_of_repulsive_terms_per_object = torch.sum(M_inv, dim=0).clamp(min=1)
        L_V_repulsive = L_V_repulsive.view(
            -1
        ) / number_of_repulsive_terms_per_object.view(-1)
        if loss_type == "vrepweighted":
            L_V_repulsive = torch.sum(
                modified_showers.view(-1) * L_V_repulsive.view(-1)
            ) / len(modified_showers)
        else:
            L_V_repulsive = torch.mean(L_V_repulsive)
    else:
        # Sum over hits, then sum per event, then divide by
        # n_hits_per_event * (n_objects_per_event - 1), then sum up events
        nope = n_objects_per_event - 1
        nope[nope == 0] = 1
        L_V_repulsive = (
            scatter_add(V_repulsive.sum(dim=0), batch_object)
            / (n_hits_per_event * nope)
        ).sum()
    L_V = attr_weight * L_V_attractive + repul_weight * L_V_repulsive

    # ________________________________
    # L_beta term
    n_noise_hits_per_event = scatter_count(batch[is_noise])
    n_noise_hits_per_event[n_noise_hits_per_event == 0] = 1
    L_beta_noise = (
        s_B
        * (
            (scatter_add(beta[is_noise], batch[is_noise])) / n_noise_hits_per_event
        ).sum()
    )
    if loss_type == "hgcalimplementation":
        beta_per_object_c = scatter_add(beta[is_sig], object_index)
        L_beta_sig = torch.mean(
            1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1)
        )
        # ? note: the training that worked quite well was dividing this by the batch size (1/4)
        # NOTE(review): the divisor is a hard-coded 4, not batch_size.
        L_beta_noise = L_beta_noise / 4
    elif loss_type == "vrepweighted":
        beta_per_object_c = scatter_add(beta[is_sig], object_index)
        L_beta_sig = 1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1)
        L_beta_sig = torch.sum(L_beta_sig.view(-1) * modified_showers.view(-1))
        L_beta_sig = L_beta_sig / len(modified_showers)
        L_beta_noise = L_beta_noise / batch_size
    elif beta_term_option == "paper":
        L_beta_sig = torch.sum(  # maybe 0.5 for less aggressive loss
            scatter_add((1 - beta_alpha), batch_object) / n_objects_per_event
        )
    elif beta_term_option == "short-range-potential":
        # First collect the norms: We only want norms of hits w.r.t. the object they
        # belong to (like in V_attractive)
        # Apply transformation first, and then apply mask to keep only the norms we want,
        # then sum over hits, so the result is (n_objects,)
        norms_beta_sig = (1.0 / (20.0 * norms[is_sig] ** 2 + 1.0) * M[is_sig]).sum(
            dim=0
        )
        assert torch.all(norms_beta_sig >= 1.0) and torch.all(
            norms_beta_sig <= n_hits_per_object
        )
        # Subtract from 1. to remove self interaction, divide by number of hits per object
        norms_beta_sig = (1.0 - norms_beta_sig) / n_hits_per_object
        assert torch.all(norms_beta_sig >= -1.0) and torch.all(norms_beta_sig <= 0.0)
        norms_beta_sig *= beta_alpha
        # Conclusion:
        # lower beta --> higher loss (less negative)
        # higher norms --> higher loss
        # Sum over objects, divide by number of objects per event, then sum over events
        L_beta_norms_term = (
            scatter_add(norms_beta_sig, batch_object) / n_objects_per_event
        ).sum()
        assert L_beta_norms_term >= -batch_size and L_beta_norms_term <= 0.0
        # Logbeta term: Take -.2*torch.log(beta_alpha+1e-9), sum it over objects,
        # divide by n_objects_per_event, then sum over events (same pattern as above)
        # lower beta --> higher loss
        L_beta_logbeta_term = (
            scatter_add(-0.2 * torch.log(beta_alpha + 1e-9), batch_object)
            / n_objects_per_event
        ).sum()
        # Final L_beta term
        L_beta_sig = L_beta_norms_term + L_beta_logbeta_term
    else:
        valid_options = ["paper", "short-range-potential"]
        raise ValueError(
            f'beta_term_option "{beta_term_option}" is not valid, choose from {valid_options}'
        )
    L_beta = L_beta_noise + L_beta_sig
    if beta_type == "pt" or beta_type == "pt+bc":
        # q is taken directly from beta here, so no beta loss is reported.
        # Created on `device` so every returned tensor lives on the same device.
        L_beta = torch.tensor(0.0, device=device)
        L_beta_sig = torch.tensor(0.0, device=device)
        L_beta_noise = torch.tensor(0.0, device=device)

    # Coordinate loss: we just compare the direction of the input coordinates with
    # the direction of the predicted cluster-space coordinates.
    # NOTE(review): a hit with an all-zero coordinate row gives a 0/0 = NaN here.
    x_original = original_coords / torch.norm(original_coords, p=2, dim=1).view(-1, 1)
    x_virtual = cluster_space_coords / torch.norm(
        cluster_space_coords, p=2, dim=1
    ).view(-1, 1)
    loss_coord = torch.mean(torch.norm(x_original - x_virtual, p=2, dim=1))

    if beta_type == "pt+bc":
        assert noise_logits is not None
        # Target is 1 for signal hits and 0 for noise hits
        y_true_noise = 1 - is_noise.float()
        num_positives = torch.sum(y_true_noise).item()
        num_negatives = len(y_true_noise) - num_positives
        num_all = len(y_true_noise)
        # Compute class-balancing weights
        pos_weight = num_all / num_positives if num_positives > 0 else 0
        neg_weight = num_all / num_negatives if num_negatives > 0 else 0
        weight = pos_weight * y_true_noise + neg_weight * (1 - y_true_noise)
        # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably
        # `noise_logits` has already been passed through a sigmoid -- confirm.
        L_bc = torch.nn.BCELoss(weight=weight)(noise_logits, y_true_noise)

    result = {
        "loss_potential": L_V,
        "loss_beta": L_beta,
        "loss_beta_sig": L_beta_sig,  # signal part of the betas
        "loss_beta_noise": L_beta_noise,  # noise part of the betas
        "loss_attractive": L_V_attractive,
        "loss_repulsive": L_V_repulsive,
        "loss_coord": loss_coord,
    }
    if beta_type == "pt+bc":
        result["loss_noise_classification"] = L_bc
    return result
def huber(d, delta):
    """
    See: https://en.wikipedia.org/wiki/Huber_loss#Definition
    Multiplied by 2 w.r.t Wikipedia version (aligning with Jan's definition)

    Returns ``d**2`` where ``|d| <= delta`` and
    ``delta**2 + 2 * delta * (|d| - delta)`` elsewhere. The ``delta**2`` offset
    makes the linear branch join the quadratic one continuously at
    ``|d| == delta``; without it the value dropped from ``delta**2`` to 0 there.
    """
    abs_d = torch.abs(d)
    quadratic = d**2
    linear = delta**2 + 2.0 * delta * (abs_d - delta)
    return torch.where(abs_d <= delta, quadratic, linear)
def batch_cluster_indices(
    cluster_id: torch.Tensor, batch: torch.Tensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Converts per-event cluster indices into indices that are unique in the batch.

    Every event is shifted by the number of clusters in all preceding events,
    where the number of clusters of an event is its largest cluster id plus one.

    Example:
        cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
        batch      = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
        -->
        offset     = torch.LongTensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 5, 5, 5])
        output     = torch.LongTensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6])

    Returns:
        (batch-wide cluster index per hit, number of clusters per event)
    """
    assert cluster_id.device == batch.device
    device = cluster_id.device
    # Largest cluster id per event, plus one, is the cluster count of that event
    max_id_per_event, _ = scatter_max(cluster_id, batch, dim=-1)
    n_clusters_per_event = max_id_per_event + 1
    # Exclusive cumulative sum: event k starts after all clusters of events < k
    leading_zero = torch.zeros(1, device=device)
    running_total = n_clusters_per_event[:-1].cumsum(dim=-1)
    start_per_event = torch.cat((leading_zero, running_total))
    # Look the start offset up for every hit
    start_per_hit = torch.gather(start_per_event, 0, batch).long()
    return start_per_hit + cluster_id, n_clusters_per_event
def get_clustering_np(
    betas: np.array, X: np.array, tbeta: float = 0.1, td: float = 1.0
) -> np.array:
    """
    Clusters hits around condensation points, numpy version.

    Hits with ``beta > tbeta`` are condensation point candidates and are visited
    in order of decreasing beta. Each candidate claims every still-unassigned hit
    closer than ``td`` to it in the cluster space ``X``; a claimed hit is never
    reassigned. The cluster label of a hit is the index of the condensation point
    that claimed it, and hits that are never claimed are labelled -1 (background).

    Returns an int32 array of shape (n_points,).
    """
    n_points = betas.shape[0]
    above_threshold = betas > tbeta
    # Candidate indices, highest beta first
    candidates = np.nonzero(above_threshold)[0]
    by_decreasing_beta = np.argsort(-betas[above_threshold])
    candidates = candidates[by_decreasing_beta]

    clustering = np.full(n_points, -1, dtype=np.int32)
    remaining = np.arange(n_points)
    for condpoint in candidates:
        distance = np.linalg.norm(X[remaining] - X[condpoint], axis=-1)
        within_reach = distance < td
        clustering[remaining[within_reach]] = condpoint
        # Only hits that were out of reach stay available for later candidates
        remaining = remaining[~within_reach]
    return clustering
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=1.0):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.

    Hits with ``beta > tbeta`` are condensation point candidates and are visited
    in order of decreasing beta. Each candidate claims every still-unassigned hit
    closer than ``td`` to it; claimed hits are never reassigned. Hits that remain
    unassigned are labelled -1 (background).

    The bookkeeping tensors are allocated on ``X.device`` so that the boolean
    distance masks (which live on that device) can index them; creating them on
    the CPU made the function fail for inputs on another device.

    Returns a long tensor of shape (n_points,) on ``X.device``.
    """
    device = X.device
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    # Get indices passing the threshold, as a flat (n_condpoints,) tensor
    indices_condpoints = select_condpoints.nonzero().view(-1)
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    # Assign points to condensation points
    # Only assign previously unassigned points (no overwriting)
    # Points unassigned at the end are bkg (-1)
    unassigned = torch.arange(n_points, device=device)
    clustering = torch.full((n_points,), -1, dtype=torch.long, device=device)
    for index_condpoint in indices_condpoints:
        d = torch.norm(X[unassigned] - X[index_condpoint], dim=-1)
        within_td = d < td
        clustering[unassigned[within_td]] = index_condpoint
        unassigned = unassigned[~within_td]
    return clustering
def scatter_count(input: torch.Tensor):
    """
    Counts how often every index value occurs in ``input``.

    Example:
        scatter_count(torch.Tensor([0, 0, 0, 1, 1, 2, 2]))  ->  [3, 2, 2]

    Index assumptions work like in torch_scatter, so index values that do not
    occur get a count of zero:
        scatter_count(torch.Tensor([1, 1, 1, 2, 2, 4, 4]))  ->  [0, 3, 2, 0, 2]
    """
    index = input.long()
    # Adding a 1 per entry into the slot of its index yields the counts
    one_per_entry = torch.ones_like(input, dtype=torch.long)
    return scatter_add(one_per_entry, index)
def scatter_counts_to_indices(input: torch.LongTensor) -> torch.LongTensor:
    """
    Expands a tensor of counts into the matching index tensor, i.e. position
    ``i`` is repeated ``input[i]`` times. This is the inverse operation of
    scatter_count.

    Example:
        input:  [3, 2, 2]
        output: [0, 0, 0, 1, 1, 2, 2]
    """
    positions = torch.arange(input.size(0), device=input.device)
    repeated = positions.repeat_interleave(input)
    return repeated.long()
def get_inter_event_norms_mask(
    batch: torch.LongTensor, nclusters_per_event: torch.LongTensor
):
    """
    Builds an (nhits x nclusters) mask that is 1 where hit i and cluster j belong
    to the same event and 0 everywhere else.

    Example:
        cluster_id_per_event = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
        batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
        (i.e. nclusters_per_event = [3, 2, 2])
    Should return:
        torch.LongTensor([
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 1, 1],
            ])
    """
    # One row per event, one column per hit; for the example above this is the
    # (nevents x nhits) matrix
    # [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]
    event_ids = torch.arange(batch.max() + 1, dtype=torch.long, device=batch.device)
    hit_is_in_event = (event_ids.unsqueeze(-1) == batch).long()
    # Duplicate every event row once per cluster of that event, giving
    # (nclusters x nhits), and transpose to (nhits x nclusters)
    hit_is_in_cluster_event = torch.repeat_interleave(
        hit_is_in_event, nclusters_per_event, dim=0
    )
    return hit_is_in_cluster_event.T
def isin(ar1, ar2):
    """Element-wise membership test: True where an entry of ar1 equals any entry of ar2.

    To be replaced by torch.isin for newer releases of torch.
    """
    # Compare every element of ar1 (extra trailing axis) against all of ar2
    pairwise_equal = ar1[..., None] == ar2
    return pairwise_equal.any(-1)
def reincrementalize(y: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Re-indexes y so that missing clusters are no longer counted.

    Within every event the cluster indices are made contiguous (0, 1, 2, ...)
    while their relative order is kept.

    Example:
        y = torch.LongTensor([
            0, 0, 0, 1, 1, 3, 3,
            0, 0, 0, 0, 0, 2, 2, 3, 3,
            0, 0, 1, 1
            ])
        batch = torch.LongTensor([
            0, 0, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2,
            ])
        reincrementalize(y, batch)
        -> tensor([0, 0, 0, 1, 1, 2, 2, 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1])
    """
    # Batch-wide unique cluster index per hit and the (hole-including) cluster
    # count per event
    y_offset, n_per_event = batch_cluster_indices(y, batch)
    n_clusters = n_per_event.sum()
    # Batch-wide cluster slots that no hit points to
    holes = (
        (~isin(torch.arange(n_clusters, device=y.device), y_offset))
        .nonzero()
        .squeeze(-1)
    )
    n_per_event_without_holes = n_per_event.clone()
    n_per_event_cumsum = n_per_event.cumsum(0)
    # Go from the highest hole down so that closing a hole never moves the
    # position of a hole that still has to be processed
    for hole in holes.sort(descending=True).values:
        y_offset[y_offset > hole] -= 1
        # Event k owns the slots [cumsum[k-1], cumsum[k]), so the owner is the
        # first event whose cumulative count exceeds the hole. The comparison
        # must be `>=`: with `>` a hole in the very first slot of an event k > 0
        # (hole == cumsum[k-1]) was wrongly attributed to an earlier event.
        i_event = (hole >= n_per_event_cumsum).long().argmin()
        n_per_event_without_holes[i_event] -= 1
    # Start offset of every event after the holes are removed
    offset_per_event = torch.zeros_like(n_per_event_without_holes)
    offset_per_event[1:] = n_per_event_without_holes.cumsum(0)[:-1]
    offset_without_holes = torch.gather(offset_per_event, 0, batch).long()
    # Back from batch-wide indices to per-event indices
    reincrementalized = y_offset - offset_without_holes
    return reincrementalized
def L_clusters_calc(batch, cluster_space_coords, cluster_index, frac_combinations, q):
    """Hinge-embedding loss over randomly sampled hit pairs, weighted by ``q``.

    For every event (unique value in ``batch``) that contains more than one
    cluster, and for every cluster in that event with ``n_pos`` hits,
    ``int(frac_combinations * n_pos**2 / 2)`` same-cluster ("positive") pairs
    and equally many cross-cluster ("negative") pairs are drawn with
    replacement using ``np.random.randint``. The distance between the two
    hits of a pair is fed to ``torch.nn.HingeEmbeddingLoss`` with target +1
    (positive) or -1 (negative), and each pair loss is weighted by ``q`` of
    the pair's first hit.

    Args:
        batch: Per-hit event index.
        cluster_space_coords: Per-hit coordinates, indexed along dim 0 like
            ``batch``.
        cluster_index: Per-hit cluster id.
        frac_combinations: Fraction of ``n_pos**2 / 2`` pairs to sample per
            cluster.
        q: Per-hit weight; assumed to be 1-D of length n_hits -- TODO confirm.

    Returns:
        Scalar tensor on ``q.device``: the weighted pair losses summed over
        all events and divided by the total number of sampled pairs, or zero
        if no pair was sampled.
    """
    device = q.device
    # reduction="none" keeps one loss value per pair so that the per-pair
    # weights are applied to the matching pair.
    hinge = torch.nn.HingeEmbeddingLoss(reduction="none")
    # Accumulators live outside the event loop: the loss is summed over all
    # events and is well defined even when every event is skipped.
    L_clusters = torch.tensor(0.0, device=device)
    number_of_pairs = 0
    for batch_id in batch.unique():
        bmask = batch == batch_id
        event_clusters = cluster_index[bmask]
        cluster_ids = event_clusters.unique()
        if len(cluster_ids) <= 1:
            # A single cluster offers no negative pairs.
            continue
        pos_pairs_all = []
        neg_pairs_all = []
        for cluster in cluster_ids:
            in_cluster = event_clusters == cluster
            # Event-level row indices of the hits inside / outside the cluster.
            pos_idx = in_cluster.nonzero().squeeze(-1)
            neg_idx = (~in_cluster).nonzero().squeeze(-1)
            n_pos, n_neg = len(pos_idx), len(neg_idx)
            if n_neg == 0:
                continue
            total_num = (n_pos ** 2) / 2
            num = int(frac_combinations * total_num)
            if num == 0:
                continue
            idx_device = pos_idx.device
            pos_draw = torch.from_numpy(
                np.random.randint(n_pos, size=(num, 2))
            ).long().to(idx_device)
            neg_anchor = torch.from_numpy(
                np.random.randint(n_pos, size=num)
            ).long().to(idx_device)
            neg_other = torch.from_numpy(
                np.random.randint(n_neg, size=num)
            ).long().to(idx_device)
            # Translate the cluster-local draws into event-level row indices,
            # since the pairs index the event-level tensors below.
            pos_pairs_all.append(pos_idx[pos_draw])
            neg_pairs_all.append(
                torch.stack([pos_idx[neg_anchor], neg_idx[neg_other]], dim=1)
            )
        if not pos_pairs_all:
            continue
        pos_pairs = torch.cat(pos_pairs_all)
        neg_pairs = torch.cat(neg_pairs_all)
        coords = cluster_space_coords[bmask]
        qs = q[bmask]
        pos_norms = (coords[pos_pairs[:, 0]] - coords[pos_pairs[:, 1]]).norm(dim=-1)
        neg_norms = (coords[neg_pairs[:, 0]] - coords[neg_pairs[:, 1]]).norm(dim=-1)
        weights = torch.cat([qs[pos_pairs[:, 0]], qs[neg_pairs[:, 0]]])
        norms = torch.cat([pos_norms, neg_norms])
        ys = torch.cat([torch.ones_like(pos_norms), -torch.ones_like(neg_norms)])
        L_clusters = L_clusters + torch.sum(weights * hinge(norms, ys))
        number_of_pairs += norms.shape[0]
    if number_of_pairs > 0:
        L_clusters = L_clusters / number_of_pairs
    return L_clusters
def calc_eta_phi(coords, return_stacked=True):
    """
    Convert cartesian coordinates to (eta, phi).

    The first three columns of ``coords`` are read as x, y, z.
    ``phi = atan2(y, x)`` and ``eta = atanh(z / sqrt(x^2 + y^2 + z^2))``.

    Returns a tensor with columns ``[eta, phi]`` stacked along dim 1, or the
    tuple ``(eta, phi)`` when ``return_stacked`` is false.
    """
    px, py, pz = coords[:, 0], coords[:, 1], coords[:, 2]
    magnitude = torch.sqrt(px**2 + py**2 + pz**2)
    azimuth = torch.atan2(py, px)
    pseudorapidity = torch.atanh(pz / magnitude)
    if return_stacked:
        return torch.stack([pseudorapidity, azimuth], dim=1)
    return pseudorapidity, azimuth
def loss_func_aug(y_pred, y_pred_aug, batch, batch_aug, event, event_aug):
    """MSE between augmented-event coordinates and their original counterparts.

    ``batch_aug.original_particle_mapping`` holds, per augmented particle,
    either -1 (no counterpart) or an index that is shifted here by the
    matching entry of ``event.pfcands.batch_number[:-1]`` to obtain a row of
    ``y_pred``. Which offset applies is decided by the half-open index ranges
    given by consecutive entries of ``event_aug.pfcands.batch_number``.

    Args:
        y_pred: Model output for the original events; the first 3 columns
            are used.
        y_pred_aug: Model output for the augmented events; the first 3
            columns are used.
        batch: Unused; kept for interface compatibility.
        batch_aug: Object exposing ``original_particle_mapping`` (integer
            tensor). It is left unmodified.
        event: Object exposing ``pfcands.batch_number`` for the original
            events.
        event_aug: Object exposing ``pfcands.batch_number`` for the augmented
            events.

    Returns:
        Scalar MSE loss over all augmented particles that have a counterpart.

    Raises:
        ValueError: If an offset to be applied is negative or a shifted
            mapping entry does not address a row of ``y_pred``.
    """
    coords_pred = y_pred[:, :3]
    coords_pred_aug = y_pred_aug[:, :3]
    # Work on a copy: shifting the caller's tensor in place would apply the
    # offsets again on every subsequent call with the same batch.
    mapping = batch_aug.original_particle_mapping.clone()
    has_original = mapping != -1
    to_add_to_batch = event.pfcands.batch_number[:-1]
    aug_batch_num = event_aug.pfcands.batch_number
    for i in range(len(aug_batch_num) - 1):
        start, stop = int(aug_batch_num[i]), int(aug_batch_num[i + 1])
        segment_mask = has_original[start:stop]
        if not segment_mask.any():
            continue
        if to_add_to_batch[i] < 0:
            raise ValueError("Batch number should be >= 0: " + str(to_add_to_batch[i]))
        # Basic slicing returns a view, so this writes through to ``mapping``.
        segment = mapping[start:stop]
        segment[segment_mask] += to_add_to_batch[i]
    if mapping.numel() > 0 and not mapping.max() < len(coords_pred):
        raise ValueError(
            "Original particle mapping out of bounds: "
            f"max index {int(mapping.max())}, "
            f"coords shapes {tuple(coords_pred.shape)} / {tuple(coords_pred_aug.shape)}, "
            f"batch number in event {event.pfcands.batch_number}, "
            f"batch number in event aug {event_aug.pfcands.batch_number}"
        )
    valid = mapping != -1
    coords_pred_aug_target = coords_pred[mapping[valid]]
    coords_pred_aug_output = coords_pred_aug[valid]
    return torch.nn.MSELoss()(coords_pred_aug_output, coords_pred_aug_target)
def object_condensation_loss(
    batch,  # input event
    pred,
    labels,
    batch_numbers,
    q_min=3.0,
    frac_clustering_loss=0.1,
    attr_weight=1.0,
    repul_weight=1.0,
    fill_loss_weight=1.0,
    use_average_cc_pos=0.0,
    loss_type="hgcalimplementation",
    clust_space_norm="none",
    dis=False,
    coord_weight=0.0,
    beta_type="default",
    lorentz_norm=False,
    spatial_part_only=False,
    loss_quark_distance=False,
    oc_scalars=False,
    loss_obj_score=False,
):
    """
    Compute the training loss from the model output.

    Either the object-condensation loss (via ``calc_LV_Lbeta``) or, when
    ``loss_quark_distance`` is set, a direct distance loss between predicted
    coordinates and ``labels.labels_coordinates``. With ``beta_type="pt+bc"``
    a weighted binary cross-entropy noise-classification loss is added.

    :param batch: Model input. ``batch.input_vectors`` is read, and
        ``batch.pt`` when ``beta_type`` is ``"pt"`` or ``"pt+bc"``.
    :param pred: Model output of shape (n, S): clustering coordinates,
        followed by one extra column (beta or noise score) unless
        ``beta_type == "pt"``.
    :param labels: In the clustering branch, the per-hit truth cluster index
        tensor (or an object with ``.labels`` when ``loss_obj_score`` is
        set). In the quark-distance branch and for ``"pt+bc"``, an object
        with ``.labels`` (-1 marks noise) and ``.labels_coordinates``.
    :param batch_numbers: Per-hit event index, passed to ``calc_LV_Lbeta``.
    :param q_min, attr_weight, repul_weight, use_average_cc_pos, loss_type:
        Forwarded to ``calc_LV_Lbeta``.
    :param frac_clustering_loss: Accepted but not used by this function.
    :param fill_loss_weight: Accepted but not used by this function.
    :param clust_space_norm: ``"twonorm"``, ``"tanh"`` or ``"none"``;
        normalisation applied to the clustering coordinates.
    :param dis: If true, the last column of ``pred`` is passed on as a
        per-hit distance threshold.
    :param coord_weight: Weight of ``loss_coord`` (added only if > 0).
    :param beta_type: ``"default"``, ``"pt"`` or ``"pt+bc"``.
    :param lorentz_norm, spatial_part_only: Forwarded to ``calc_LV_Lbeta``
        and used to pick the distance in the quark-distance branch.
    :param loss_quark_distance: Select the quark-distance branch.
    :param oc_scalars: If true, only columns 1:4 of ``batch.input_vectors``
        are used as original coordinates.
    :param loss_obj_score: If true, ``labels.labels + 1`` is used as the
        cluster index.
    :return: Tuple ``(loss, a)`` where ``a`` is the dict of loss components.
    :raises ValueError: For an unknown ``beta_type``.
    :raises NotImplementedError: For an unknown ``clust_space_norm``.
    """
    _, S = pred.shape
    noise_logits = None
    if beta_type == "default":
        clust_space_dim = S - 1
        bj = torch.sigmoid(torch.reshape(pred[:, clust_space_dim], [-1, 1]))  # betas
    elif beta_type == "pt":
        bj = batch.pt
        clust_space_dim = S
    elif beta_type == "pt+bc":
        bj = batch.pt
        clust_space_dim = S - 1
        noise_logits = pred[:, clust_space_dim]
    else:
        # Fail early with a clear message rather than on an unbound variable below.
        raise ValueError(f"Unknown beta_type: {beta_type!r}")
    original_coords = batch.input_vectors
    if oc_scalars:
        original_coords = original_coords[:, 1:4]
    if dis:
        distance_threshold = torch.reshape(pred[:, -1], [-1, 1])
    else:
        distance_threshold = 0
    xj = pred[:, :clust_space_dim]  # Coordinates in clustering space
    if clust_space_norm == "twonorm":
        xj = torch.nn.functional.normalize(xj, dim=1)
    elif clust_space_norm == "tanh":
        xj = torch.tanh(xj)
    elif clust_space_norm == "none":
        pass
    else:
        raise NotImplementedError(f"Unknown clust_space_norm: {clust_space_norm!r}")
    if not loss_quark_distance:
        clustering_index_l = labels
        if loss_obj_score:
            clustering_index_l = labels.labels + 1
        a = calc_LV_Lbeta(
            original_coords,
            batch,
            distance_threshold,
            beta=bj.view(-1),
            cluster_space_coords=xj,  # Predicted by model
            cluster_index_per_event=clustering_index_l.view(
                -1
            ).long(),  # Truth hit->cluster index
            batch=batch_numbers.long(),
            qmin=q_min,
            attr_weight=attr_weight,
            repul_weight=repul_weight,
            use_average_cc_pos=use_average_cc_pos,
            loss_type=loss_type,
            dis=dis,
            beta_type=beta_type,
            noise_logits=noise_logits,
            lorentz_norm=lorentz_norm,
            spatial_part_only=spatial_part_only,
        )
        loss = a["loss_potential"] + a["loss_beta"]
        if coord_weight > 0:
            loss = loss + a["loss_coord"] * coord_weight
    else:
        # quark distance loss
        not_noise = labels.labels != -1
        target_coords = labels.labels_coordinates[labels.labels[not_noise]]
        if lorentz_norm:
            # NOTE(review): this branch indexes labels_coordinates with the
            # boolean mask (not with the label values as target_coords does)
            # and treats ``diff`` as 3-D -- it looks inconsistent with the
            # other branch; confirm before relying on it.
            diff = xj[not_noise] - labels.labels_coordinates[not_noise]
            norms = diff[:, :, 0] ** 2 - torch.sum(diff[:, :, 1:] ** 2, dim=-1)
            norms = norms.abs()
        else:
            if spatial_part_only:
                x_coords = xj[not_noise, 1:4]
                x_true = target_coords[:, 1:4]
            else:
                x_coords = xj[not_noise]
                x_true = target_coords
            # Cosine distance (1 - cosine similarity).
            # NOTE(review): x_true is sliced with [:, 1:4] again here, which
            # slices twice when spatial_part_only is set -- confirm intent.
            norms = 2 - (
                torch.nn.functional.cosine_similarity(x_coords, x_true[:, 1:4], dim=1)
                + 1
            )
        a = {"norms_loss": torch.mean(norms)}
        loss = a["norms_loss"]
    if beta_type == "pt+bc":
        # TODO: polish this, it's another loss that should be computed outside calc_LV_Lbeta
        assert noise_logits is not None
        is_noise = labels.labels == -1
        y_true_noise = 1 - is_noise.float()
        num_positives = torch.sum(y_true_noise).item()
        num_negatives = len(y_true_noise) - num_positives
        num_all = len(y_true_noise)
        # Inverse-frequency class weights; a class that is absent gets weight 0.
        pos_weight = num_all / num_positives if num_positives > 0 else 0
        neg_weight = num_all / num_negatives if num_negatives > 0 else 0
        weight = pos_weight * y_true_noise + neg_weight * (1 - y_true_noise)
        # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably the
        # model already squashes this column despite the "logits" name -- confirm.
        L_bc = torch.nn.BCELoss(weight=weight)(noise_logits, y_true_noise)
        a["loss_noise_classification"] = L_bc
        # Out-of-place add: ``loss`` may be the very tensor stored in ``a``
        # (quark-distance branch); an in-place ``+=`` would also alter the
        # reported component.
        loss = loss + L_bc
    return loss, a