Spaces:
Running
Running
| """ | |
| Standalone single-event MLPF inference. | |
| Provides :func:`run_single_event_inference` which takes raw event data | |
| (from a parquet file or as an awkward record) and model checkpoint paths, | |
| runs the full particle-flow pipeline (graph construction → GATr forward | |
| pass → density-peak clustering → energy correction & PID), and returns: | |
| * a ``pandas.DataFrame`` of predicted particles with their properties | |
| * a hit→cluster mapping as a ``pandas.DataFrame`` | |
| """ | |
| import argparse | |
| import types | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import dgl | |
| import awkward as ak | |
| from src.data.fileio import _read_parquet | |
| from src.dataset.functions_graph import create_graph | |
| from src.dataset.functions_particles import Particles_GT, add_batch_number | |
| from src.layers.clustering import DPC_custom_CLD, remove_bad_tracks_from_cluster | |
| from src.utils.pid_conversion import pid_conversion_dict | |
| # -- CPU-compatible attention patch ------------------------------------------ | |
def _patch_gatr_attention_for_cpu():
    """Monkey-patch GATr so its attention primitive runs without xformers.

    ``xformers.ops.fmha.memory_efficient_attention`` has no CPU kernel, so
    running GATr on CPU crashes.  This swaps
    ``gatr.primitives.attention.scaled_dot_product_attention`` for a plain
    PyTorch implementation that works on any device (albeit slower on GPU).
    Idempotent: applying the patch a second time is a no-op.
    """
    import gatr.primitives.attention as _attn_mod

    if getattr(_attn_mod, "_cpu_patched", False):
        return  # already patched

    def _naive_sdpa(q, k, v, attn_mask=None):
        # q, k, v: (batch, heads, items, dim)
        batch, heads, items, dim = q.shape
        flatten = lambda t: t.reshape(batch * heads, items, dim)
        logits = torch.bmm(
            flatten(q) * (float(dim) ** -0.5), flatten(k).transpose(1, 2)
        )  # (batch*heads, items, items)
        if attn_mask is not None:
            dense = _block_diag_mask_to_dense(attn_mask, items, q.device)
            if dense is not None:
                logits = logits.masked_fill(~dense.unsqueeze(0), float("-inf"))
        weights = torch.softmax(logits, dim=-1)
        # A fully-masked row softmaxes to NaN; replace those rows with zeros.
        weights = weights.nan_to_num(0.0)
        return torch.bmm(weights, flatten(v)).reshape(batch, heads, items, dim)

    _attn_mod.scaled_dot_product_attention = _naive_sdpa
    _attn_mod._cpu_patched = True
| def _block_diag_mask_to_dense(attn_mask, total_len, device): | |
| """Convert an ``xformers.ops.fmha.BlockDiagonalMask`` to a dense bool mask.""" | |
| try: | |
| from xformers.ops.fmha.attn_bias import BlockDiagonalMask | |
| if not isinstance(attn_mask, BlockDiagonalMask): | |
| return None | |
| except ImportError: | |
| return None | |
| # Extract per-sequence start offsets | |
| try: | |
| seqstarts = attn_mask.q_seqinfo.seqstart_py | |
| except AttributeError: | |
| try: | |
| seqstarts = attn_mask.q_seqinfo.seqstart.cpu().tolist() | |
| except Exception: | |
| return None | |
| mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=device) | |
| for i in range(len(seqstarts) - 1): | |
| s, e = seqstarts[i], seqstarts[i + 1] | |
| mask[s:e, s:e] = True | |
| return mask | |
# -- PID label → human-readable name ----------------------------------------
# Class index emitted by the PID classification head → display name.
# NOTE(review): indices must stay in sync with the training label encoding —
# confirm against the checkpoint's label mapping.
_PID_LABELS = {
    0: "electron",
    1: "charged hadron",
    2: "neutral hadron",
    3: "photon",
    4: "muon",
}
# |PDG ID| → short human-readable particle name; used when tabulating
# MC-truth particles and Pandora PFOs.  Unknown IDs fall back to str(pid).
_ABS_PDG_NAME = {
    11: "electron",
    13: "muon",
    22: "photon",
    130: "K_L",
    211: "pion±",
    321: "kaon±",
    2112: "neutron",
    2212: "proton",
    310: "K_S",
}
| # -- Minimal args namespace for inference ------------------------------------ | |
| def _default_args(**overrides): | |
| """Return a minimal ``argparse.Namespace`` with defaults the model expects.""" | |
| d = dict( | |
| correction=True, | |
| freeze_clustering=True, | |
| predict=True, | |
| pandora=False, | |
| use_gt_clusters=False, | |
| use_average_cc_pos=0.99, | |
| qmin=1.0, | |
| data_config="config_files/config_hits_track_v4.yaml", | |
| network_config="src/models/wrapper/example_mode_gatr_noise.py", | |
| model_prefix="/tmp/mlpf_eval", | |
| start_lr=1e-3, | |
| frac_cluster_loss=0, | |
| local_rank=0, | |
| gpus="0", | |
| batch_size=1, | |
| num_workers=0, | |
| prefetch_factor=1, | |
| num_epochs=1, | |
| steps_per_epoch=None, | |
| samples_per_epoch=None, | |
| steps_per_epoch_val=None, | |
| samples_per_epoch_val=None, | |
| train_val_split=0.8, | |
| data_train=[], | |
| data_val=[], | |
| data_test=[], | |
| data_fraction=1, | |
| file_fraction=1, | |
| fetch_by_files=True, | |
| fetch_step=1, | |
| log_wandb=False, | |
| wandb_displayname="", | |
| wandb_projectname="", | |
| wandb_entity="", | |
| name_output="gradio", | |
| train_batches=100, | |
| ) | |
| d.update(overrides) | |
| return argparse.Namespace(**d) | |
| # -- Model loading ----------------------------------------------------------- | |
def load_model(
    clustering_ckpt: str,
    energy_pid_ckpt: Optional[str] = None,
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Load the full MLPF model (clustering + optional energy/PID correction).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (``.ckpt``).
    energy_pid_ckpt : str or None
        Path to the energy-correction / PID checkpoint (``.ckpt``).  When
        *None*, only clustering weights are loaded (no correction / PID).
    device : str
        Torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
    args_overrides : dict or None
        Extra key/value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    extra = dict(args_overrides or {})
    extra["correction"] = energy_pid_ckpt is not None
    args = _default_args(**extra)
    target = torch.device(device)

    if energy_pid_ckpt is None:
        model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=target,
        )
    else:
        # Start from the energy/PID checkpoint, then graft the clustering
        # layers from the clustering checkpoint on top of it.
        checkpoint = torch.load(energy_pid_ckpt, map_location=target)
        model = ExampleWrapper(args=args, dev=0)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
        cluster_model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=target,
        )
        for attr in ("gatr", "ScaledGooeyBatchNorm2_1", "clustering", "beta"):
            setattr(model, attr, getattr(cluster_model, attr))

    model = model.to(target)
    model.eval()
    return model, args
def load_random_model(
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Build a GATr model with freshly initialised (random) weights.

    Useful for debugging: comparing its output against a checkpointed model
    verifies that checkpoint weights really are loaded and used.

    Parameters
    ----------
    device : str
        Torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
    args_overrides : dict or None
        Extra key/value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model (random weights) in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    merged = dict(args_overrides or {})
    merged["correction"] = False
    args = _default_args(**merged)
    model = ExampleWrapper(args=args, dev=0).to(torch.device(device))
    model.eval()
    return model, args
| # -- Single-event data loading ----------------------------------------------- | |
def load_event_from_parquet(parquet_path: str, event_index: int = 0):
    """Read a single event from a parquet file.

    Parameters
    ----------
    parquet_path : str
        Path to the parquet file.
    event_index : int
        Index of the event to extract.  Negative indices count from the
        end, as with normal Python sequences.

    Returns
    -------
    dict
        Mapping of field name → awkward array for the selected event, with
        fields such as ``X_hit``, ``X_track``, ``X_gen``, ``ygen_hit``,
        ``ygen_track``, etc.

    Raises
    ------
    IndexError
        If *event_index* is out of range for the file.
    """
    table = _read_parquet(parquet_path)
    n_events = len(table["X_track"])
    # Validate both ends of the range: the previous `>= n_events` check let
    # negative indices through unvalidated, so a far-negative index raised a
    # confusing error from the underlying array library instead of this one.
    if not -n_events <= event_index < n_events:
        raise IndexError(
            f"event_index {event_index} out of range (file has {n_events} events)"
        )
    return {field: table[field][event_index] for field in table.fields}
| # -- Core inference function -------------------------------------------------- | |
# no_grad is required here: without it the clustering-head outputs track
# gradients, so the `.cpu().numpy()` calls below raise ("Can't call numpy()
# on Tensor that requires grad"), and the forward pass needlessly builds an
# autograd graph for a pure-inference path.
@torch.no_grad()
def run_single_event_inference(
    event,
    model,
    args,
    device: str = "cpu",
):
    """Run full MLPF inference on a single event.

    Parameters
    ----------
    event : dict-like
        A single event record (from :func:`load_event_from_parquet`).
    model : ExampleWrapper
        The loaded model (from :func:`load_model`).
    args : argparse.Namespace
        The arguments namespace (from :func:`load_model`).
    device : str
        Device string.

    Returns
    -------
    particles_df : pandas.DataFrame
        One row per predicted particle with columns:
        ``cluster_id``, ``energy``, ``pid_class``, ``pid_label``,
        ``px``, ``py``, ``pz``, ``is_charged``.
    hit_cluster_df : pandas.DataFrame
        One row per hit with columns:
        ``hit_index``, ``cluster_id``, ``pandora_cluster_id``,
        ``hit_type_id``, ``hit_type``, ``x``, ``y``, ``z``,
        ``hit_energy``, ``cluster_x``, ``cluster_y``, ``cluster_z``.
        ``pandora_cluster_id`` is -1 when pandora data is not available
        or when the hit has no matching entry (e.g. CSV was modified after
        loading from parquet).
    mc_particles_df : pandas.DataFrame
        One row per MC truth particle with columns:
        ``pid``, ``energy``, ``momentum``, ``px``, ``py``, ``pz``,
        ``mass``, ``theta``, ``phi``, ``vx``, ``vy``, ``vz``,
        ``gen_status``, ``pdg_name``.
    pandora_particles_df : pandas.DataFrame
        One row per Pandora PFO with columns:
        ``pfo_idx``, ``pid``, ``pdg_name``, ``energy``, ``momentum``,
        ``px``, ``py``, ``pz``, ``ref_x``, ``ref_y``, ``ref_z``.
        Empty when pandora data is not available in the input.
    """
    dev = torch.device(device)
    # Ensure eval mode so that BatchNorm layers use running statistics from
    # training instead of computing batch statistics from the current
    # (single-event) input. Without this, inference with batch_size=1
    # produces incorrect normalization.
    model.eval()
    if dev.type == "cpu":
        _patch_gatr_attention_for_cpu()
    # 0. Extract MC truth particles table and pandora particles
    mc_particles_df = _extract_mc_particles(event)
    pandora_particles_df, pfo_calohit, pfo_track = _extract_pandora_particles(event)
    # 1. Build DGL graph from the event
    [g, y_data], graph_empty = create_graph(event, for_training=False, args=args)
    if graph_empty:
        return pd.DataFrame(), pd.DataFrame(), mc_particles_df, pandora_particles_df
    g = g.to(dev)
    # Prepare batch metadata expected by the model
    y_data.batch_number = torch.zeros(y_data.E.shape[0], 1)
    # 2. Forward pass through the GATr clustering backbone
    inputs = g.ndata["pos_hits_xyz"].float().to(dev)
    inputs_scalar = g.ndata["hit_type"].float().view(-1, 1).to(dev)
    from gatr.interface import embed_point, embed_scalar
    from xformers.ops.fmha import BlockDiagonalMask
    inputs_normed = model.ScaledGooeyBatchNorm2_1(inputs)
    embedded_inputs = embed_point(inputs_normed) + embed_scalar(inputs_scalar)
    embedded_inputs = embedded_inputs.unsqueeze(-2)
    # Single event → one attention block spanning every node.
    mask = BlockDiagonalMask.from_seqlens([g.num_nodes()])
    scalars = torch.cat(
        (g.ndata["e_hits"].float().to(dev), g.ndata["p_hits"].float().to(dev)), dim=1
    )
    from gatr.interface import extract_point, extract_scalar
    embedded_outputs, scalar_outputs = model.gatr(
        embedded_inputs, scalars=scalars, attention_mask=mask
    )
    points = extract_point(embedded_outputs[:, 0, :])
    nodewise_outputs = extract_scalar(embedded_outputs)
    x_point = points
    x_scalar = torch.cat(
        (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
    )
    x_cluster_coord = model.clustering(x_point)
    beta = model.beta(x_scalar)
    g.ndata["final_cluster"] = x_cluster_coord
    g.ndata["beta"] = beta.view(-1)
    # 3. Density-peak clustering
    labels = DPC_custom_CLD(x_cluster_coord, g, dev)
    labels, _ = remove_bad_tracks_from_cluster(g, labels)
    # 4. Build hit→cluster table
    n_hits = g.num_nodes()
    hit_types_raw = g.ndata["hit_type"].cpu().numpy()
    hit_type_names = {1: "track", 2: "ECAL", 3: "HCAL", 4: "muon"}
    # Build pandora cluster ID per node (hits first, then tracks).
    # Use min of array lengths for graceful handling when CSV was modified.
    n_calo = len(np.asarray(event.get("X_hit", [])))
    pandora_cluster_ids = np.full(n_hits, -1, dtype=np.int64)
    if len(pfo_calohit) > 0:
        n_assign = min(len(pfo_calohit), n_calo)
        pandora_cluster_ids[:n_assign] = pfo_calohit[:n_assign]
    n_tracks = n_hits - n_calo
    if n_tracks > 0 and len(pfo_track) > 0:
        n_assign = min(len(pfo_track), n_tracks)
        pandora_cluster_ids[n_calo:n_calo + n_assign] = pfo_track[:n_assign]
    hit_cluster_df = pd.DataFrame({
        "hit_index": np.arange(n_hits),
        "cluster_id": labels.cpu().numpy(),
        "pandora_cluster_id": pandora_cluster_ids,
        "hit_type_id": hit_types_raw,
        "hit_type": [hit_type_names.get(int(t), str(int(t))) for t in hit_types_raw],
        "x": g.ndata["pos_hits_xyz"][:, 0].cpu().numpy(),
        "y": g.ndata["pos_hits_xyz"][:, 1].cpu().numpy(),
        "z": g.ndata["pos_hits_xyz"][:, 2].cpu().numpy(),
        "hit_energy": g.ndata["e_hits"].view(-1).cpu().numpy(),
        "cluster_x": x_cluster_coord[:, 0].cpu().numpy(),
        "cluster_y": x_cluster_coord[:, 1].cpu().numpy(),
        "cluster_z": x_cluster_coord[:, 2].cpu().numpy(),
    })
    # 5. Per-cluster summary (basic, before energy correction)
    unique_labels = torch.unique(labels)
    # cluster 0 = noise
    cluster_ids = unique_labels[unique_labels > 0].cpu().numpy()
    from torch_scatter import scatter_add
    e_per_cluster = scatter_add(
        g.ndata["e_hits"].view(-1).to(dev), labels.to(dev)
    )
    p_per_cluster = scatter_add(
        g.ndata["p_hits"].view(-1).to(dev), labels.to(dev)
    )
    n_hits_per_cluster = scatter_add(
        torch.ones(n_hits, device=dev), labels.to(dev)
    )
    # Check if any cluster has a track (→ charged)
    is_track_per_cluster = scatter_add(
        (g.ndata["hit_type"].to(dev) == 1).float(), labels.to(dev)
    )
    rows = []
    for cid in cluster_ids:
        mask_c = labels == cid
        e_sum = e_per_cluster[cid].item()
        p_sum = p_per_cluster[cid].item()
        n_h = int(n_hits_per_cluster[cid].item())
        has_track = is_track_per_cluster[cid].item() >= 1
        # Mean position
        pos_mean = g.ndata["pos_hits_xyz"][mask_c].mean(dim=0).cpu().numpy()
        rows.append({
            "cluster_id": int(cid),
            "energy_sum_hits": round(e_sum, 4),
            "p_track": round(p_sum, 4) if has_track else 0.0,
            "n_hits": n_h,
            "is_charged": has_track,
            "mean_x": round(float(pos_mean[0]), 2),
            "mean_y": round(float(pos_mean[1]), 2),
            "mean_z": round(float(pos_mean[2]), 2),
        })
    particles_df = pd.DataFrame(rows)
    # 6. If energy correction is available, run it
    if args.correction and hasattr(model, "energy_correction"):
        try:
            particles_df = _run_energy_correction(
                model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev
            )
        except Exception as e:
            # Deliberate best-effort: attach a note but don't crash —
            # the basic table is still useful without the correction.
            particles_df["note"] = f"Energy correction failed: {e}"
    return particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df
def _extract_mc_particles(event):
    """Tabulate the event's MC-truth particles (``X_gen``) as a DataFrame."""
    x_gen = np.asarray(event.get("X_gen", []))
    # Require a 2-D table with at least one particle and 18 columns.
    if x_gen.ndim != 2 or x_gen.shape[0] == 0 or x_gen.shape[1] < 18:
        return pd.DataFrame()

    def _row(idx, p):
        # One output record per truth particle; column layout of X_gen is
        # fixed by the upstream dataset writer.
        pid = int(p[0])
        return {
            "particle_idx": idx,
            "pid": pid,
            "pdg_name": _ABS_PDG_NAME.get(abs(pid), str(pid)),
            "gen_status": int(p[1]),
            "energy": round(float(p[8]), 4),
            "momentum": round(float(p[11]), 4),
            "px": round(float(p[12]), 4),
            "py": round(float(p[13]), 4),
            "pz": round(float(p[14]), 4),
            "mass": round(float(p[10]), 4),
            "theta": round(float(p[4]), 4),
            "phi": round(float(p[5]), 4),
            "vx": round(float(p[15]), 4),
            "vy": round(float(p[16]), 4),
            "vz": round(float(p[17]), 4),
        }

    return pd.DataFrame([_row(i, x_gen[i]) for i in range(x_gen.shape[0])])
def _extract_pandora_particles(event):
    """Tabulate Pandora PFOs (``X_pandora``) plus hit/track → PFO links.

    ``X_pandora`` columns (per PFO):
        0: pid (PDG ID)
        1–3: px, py, pz (momentum components at reference point)
        4–6: ref_x, ref_y, ref_z (reference point)
        7: energy
        8: momentum magnitude

    Returns ``(pandora_particles_df, pfo_hit_links, pfo_track_links)`` where
    the two link arrays map each calorimeter hit / track to a PFO index
    (0-based, -1 = unassigned).
    """
    x_pandora = np.asarray(event.get("X_pandora", []))
    hit_links = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
    track_links = np.asarray(event.get("pfo_track", []), dtype=np.int64)
    if x_pandora.ndim != 2 or x_pandora.shape[0] == 0 or x_pandora.shape[1] < 9:
        return pd.DataFrame(), hit_links, track_links
    records = []
    for idx, pfo in enumerate(x_pandora):
        pid = int(pfo[0])
        records.append({
            "pfo_idx": idx,
            "pid": pid,
            "pdg_name": _ABS_PDG_NAME.get(abs(pid), str(pid)),
            "energy": round(float(pfo[7]), 4),
            "momentum": round(float(pfo[8]), 4),
            "px": round(float(pfo[1]), 4),
            "py": round(float(pfo[2]), 4),
            "pz": round(float(pfo[3]), 4),
            "ref_x": round(float(pfo[4]), 2),
            "ref_y": round(float(pfo[5]), 2),
            "ref_z": round(float(pfo[6]), 2),
        })
    return pd.DataFrame(records), hit_links, track_links
def _run_energy_correction(model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev):
    """Run the energy correction & PID branch and enrich *particles_df*.

    Splits the clustered event into per-cluster DGL sub-graphs (showers
    matched to a truth particle, plus unmatched "fakes"), computes
    high-level per-cluster features, and runs the model's charged / neutral
    correction networks to predict corrected energies, directions and PID
    classes.

    Parameters
    ----------
    model : loaded wrapper; this function uses ``model.energy_correction``.
    g : full event graph with per-hit node data (``h``, ``e_hits``, ...).
    x_cluster_coord, beta : clustering-head outputs for every hit.
    labels : per-hit cluster assignment (0 = noise).
    y_data : truth-particle record used for shower matching.
    particles_df : basic per-cluster table; returned unchanged when no
        sub-graphs can be built.
    dev : torch device.

    Returns
    -------
    pandas.DataFrame
        One row per cluster, matched clusters first and fakes after
        (``is_fake`` marks the latter).
    """
    from src.layers.shower_matching import match_showers, obtain_intersection_matrix, obtain_union_matrix
    from torch_scatter import scatter_add, scatter_mean
    from src.utils.post_clustering_features import (
        get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
    )
    x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
    # Re-create per-cluster sub-graphs expected by the correction pipeline
    particle_ids = torch.unique(g.ndata["particle_number"])
    shower_p_unique = torch.unique(labels)
    model_output_dummy = x  # used only for device by match_showers
    # Match predicted showers to truth particles; row_ind/col_ind pair
    # truth rows with predicted-shower columns.
    shower_p_unique_m, row_ind, col_ind, i_m_w, _ = match_showers(
        labels, {"graph": g, "part_true": y_data},
        particle_ids, model_output_dummy, 0, 0, None,
    )
    row_ind = torch.Tensor(row_ind).to(dev).long()
    col_ind = torch.Tensor(col_ind).to(dev).long()
    # When a noise "particle" (id 0) exists, truth rows are shifted by one.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind
    # Cluster labels start at 1 (0 = noise), hence the +1 offset.
    index_matches = (col_ind + 1).to(dev).long()
    # Build per-cluster sub-graphs (matched + fakes)
    graphs_matched = []
    true_energies = []
    reco_energies = []
    pids_matched = []
    coords_matched = []
    e_true_daughters = []
    for j, sh_label in enumerate(index_matches):
        # Skip labels that appear more than once in the matching (ambiguous).
        if torch.sum(sh_label == index_matches) == 1:
            mask = labels == sh_label
            # Edgeless sub-graph: only node features are consumed downstream.
            sg = dgl.graph(([], []))
            sg.add_nodes(int(mask.sum()))
            sg = sg.to(dev)
            sg.ndata["h"] = g.ndata["h"][mask]
            if "pos_pxpypz" in g.ndata:
                sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
            if "pos_pxpypz_at_vertex" in g.ndata:
                sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
            sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
            energy_t = y_data.E.to(dev)
            true_e = energy_t[row_ind_[j]]
            pids_matched.append(y_data.pid[row_ind_[j]].item())
            coords_matched.append(y_data.coord[row_ind_[j]].detach().cpu().numpy())
            e_true_daughters.append(y_data.m[row_ind_[j]].to(dev))
            reco_e = torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask])
            graphs_matched.append(sg)
            true_energies.append(true_e.view(-1))
            reco_energies.append(reco_e.view(-1))
    # Add fakes
    # A fake is a predicted shower that matched no truth particle.
    pred_showers = shower_p_unique_m.clone()
    pred_showers[index_matches] = -1
    pred_showers[0] = -1  # index 0 corresponds to the noise cluster
    fakes_mask = pred_showers != -1
    fakes_idx = torch.where(fakes_mask)[0]
    graphs_fakes = []
    reco_fakes = []
    for fi in fakes_idx:
        mask = labels == fi
        sg = dgl.graph(([], []))
        sg.add_nodes(int(mask.sum()))
        sg = sg.to(dev)
        sg.ndata["h"] = g.ndata["h"][mask]
        if "pos_pxpypz" in g.ndata:
            sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
        if "pos_pxpypz_at_vertex" in g.ndata:
            sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
        sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
        graphs_fakes.append(sg)
        reco_fakes.append(torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask]).view(-1))
    if not graphs_matched and not graphs_fakes:
        return particles_df
    all_graphs = dgl.batch(graphs_matched + graphs_fakes)
    sum_e = torch.cat(reco_energies + reco_fakes, dim=0)
    # Compute high-level features
    batch_num_nodes = all_graphs.batch_num_nodes()
    batch_idx = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(dev)
    # Normalise hit positions; 3300 is presumably the detector half-extent
    # in mm — NOTE(review): confirm against the training preprocessing.
    all_graphs.ndata["h"][:, 0:3] = all_graphs.ndata["h"][:, 0:3] / 3300
    graphs_sum_features = scatter_add(all_graphs.ndata["h"], batch_idx, dim=0)
    # Broadcast each cluster's feature sum back onto its own nodes.
    graphs_sum_features = graphs_sum_features[batch_idx]
    betas = torch.sigmoid(all_graphs.ndata["h"][:, -1])
    all_graphs.ndata["h"] = torch.cat(
        (all_graphs.ndata["h"], graphs_sum_features), dim=1
    )
    high_level = get_post_clustering_features(all_graphs, sum_e)
    extra_features = get_extra_features(all_graphs, betas)
    n_clusters = high_level.shape[0]
    # Value 1 is a placeholder for clusters no prediction branch writes to.
    pred_energy = torch.ones(n_clusters, device=dev)
    pred_pos = torch.ones(n_clusters, 3, device=dev)
    pred_pid = torch.ones(n_clusters, device=dev).long()
    node_features_avg = scatter_mean(all_graphs.ndata["h"], batch_idx, dim=0)[:, 0:3]
    eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
    phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
    high_level = torch.cat(
        (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
    )
    # Column 7 holds the track count; >= 1 track → charged hypothesis.
    # NOTE(review): depends on get_post_clustering_features' column layout.
    num_tracks = high_level[:, 7]
    charged_idx = torch.where(num_tracks >= 1)[0]
    neutral_idx = torch.where(num_tracks < 1)[0]
    def zero_nans(t):
        # NaN != NaN selects exactly the NaN entries; zeroed on a copy.
        out = t.clone()
        out[out != out] = 0
        return out
    feats_charged = zero_nans(high_level[charged_idx])
    feats_neutral = zero_nans(high_level[neutral_idx])
    # Run charged prediction
    charged_energies = model.energy_correction.model_charged.charged_prediction(
        all_graphs, charged_idx, feats_charged,
    )
    # Run neutral prediction
    neutral_energies, neutral_pxyz_avg = model.energy_correction.model_neutral.neutral_prediction(
        all_graphs, neutral_idx, feats_neutral,
    )
    pids_charged = model.energy_correction.pids_charged
    pids_neutral = model.energy_correction.pids_neutral
    # The prediction tuples carry PID logits only when PID classes are
    # configured for the respective branch.
    if len(pids_charged):
        ch_e, ch_pos, ch_pid_logits, ch_ref = charged_energies
    else:
        ch_e, ch_pos, _ = charged_energies
        ch_pid_logits = None
    if len(pids_neutral):
        ne_e, ne_pos, ne_pid_logits, ne_ref = neutral_energies
    else:
        ne_e, ne_pos, _ = neutral_energies
        ne_pid_logits = None
    pred_energy[charged_idx.flatten()] = ch_e if len(charged_idx) else pred_energy[charged_idx.flatten()]
    pred_energy[neutral_idx.flatten()] = ne_e if len(neutral_idx) else pred_energy[neutral_idx.flatten()]
    if ch_pid_logits is not None and len(charged_idx):
        # argmax over logits → index into the configured PID class list.
        ch_labels = np.array(pids_charged)[np.argmax(ch_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[charged_idx.flatten()] = torch.tensor(ch_labels).long().to(dev)
    if ne_pid_logits is not None and len(neutral_idx):
        ne_labels = np.array(pids_neutral)[np.argmax(ne_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[neutral_idx.flatten()] = torch.tensor(ne_labels).long().to(dev)
    # Clamp unphysical negative energy predictions.
    pred_energy[pred_energy < 0] = 0.0
    # Direction
    if len(charged_idx):
        pred_pos[charged_idx.flatten()] = ch_pos.float().to(dev)
    if len(neutral_idx):
        pred_pos[neutral_idx.flatten()] = ne_pos.float().to(dev)
    # Build enriched output DataFrame
    n_matched = len(graphs_matched)
    rows = []
    for k in range(n_clusters):
        # Fakes were appended after the matched clusters in dgl.batch order.
        is_fake = k >= n_matched
        pid_cls = int(pred_pid[k].item())
        rows.append({
            "cluster_id": k + 1,
            "corrected_energy": round(pred_energy[k].item(), 4),
            "raw_energy": round(sum_e[k].item(), 4),
            "pid_class": pid_cls,
            "pid_label": _PID_LABELS.get(pid_cls, str(pid_cls)),
            "px": round(pred_pos[k, 0].item(), 4),
            "py": round(pred_pos[k, 1].item(), 4),
            "pz": round(pred_pos[k, 2].item(), 4),
            "is_charged": bool(k in charged_idx),
            "is_fake": is_fake,
        })
    return pd.DataFrame(rows)