"""Standalone single-event MLPF inference.

Provides :func:`run_single_event_inference`, which takes raw event data (from a
parquet file or as an awkward record) together with a loaded model, runs the
full particle-flow pipeline (graph construction → GATr forward pass →
density-peak clustering → energy correction & PID), and returns:

* a ``pandas.DataFrame`` of predicted particles with their properties
* a hit→cluster mapping as a ``pandas.DataFrame``
* a ``pandas.DataFrame`` of MC truth particles
* a ``pandas.DataFrame`` of Pandora particle-flow objects (empty when the
  input carries no Pandora information)

Use :func:`load_model` (or :func:`load_random_model` for debugging) to obtain
the model and its argument namespace, and :func:`load_event_from_parquet` to
read a single event.
"""

import argparse
import types
import warnings
from typing import Optional

import numpy as np
import pandas as pd
import torch
import dgl
import awkward as ak

from src.data.fileio import _read_parquet
from src.dataset.functions_graph import create_graph
from src.dataset.functions_particles import Particles_GT, add_batch_number
from src.layers.clustering import DPC_custom_CLD, remove_bad_tracks_from_cluster
from src.utils.pid_conversion import pid_conversion_dict


# -- Event field access -------------------------------------------------------


def _event_get(event, key, default=None):
    """Return field *key* of *event*, or *default* when the field is missing.

    Works for plain mappings (e.g. the ``dict`` returned by
    :func:`load_event_from_parquet`) and for record-like objects that expose a
    ``fields`` attribute plus ``__getitem__`` but no ``dict.get`` method (such
    as an awkward record).
    """
    getter = getattr(event, "get", None)
    if callable(getter):
        return getter(key, default)
    fields = getattr(event, "fields", None)
    if fields is not None and key in fields:
        return event[key]
    return default


# -- CPU-compatible attention patch -------------------------------------------


def _patch_gatr_attention_for_cpu():
    """Replace GATr's xformers-based attention with a naive implementation.

    ``xformers.ops.fmha.memory_efficient_attention`` has no CPU kernel, so
    running GATr on CPU crashes. This function monkey-patches
    ``gatr.primitives.attention.scaled_dot_product_attention`` with a plain
    PyTorch implementation that works on any device (albeit slower on GPU).
    The patch is applied at most once (guarded by a ``_cpu_patched`` flag on
    the patched module).
    """
    import gatr.primitives.attention as _gatr_attn

    if getattr(_gatr_attn, "_cpu_patched", False):
        return

    def _cpu_sdpa(q, k, v, attn_mask=None, is_causal=False):
        """Naive scaled dot-product attention.

        Expects ``q: (..., Nq, Dqk)``, ``k: (..., Nk, Dqk)`` and
        ``v: (..., Nk, Dv)``; leading dimensions (typically batch and heads)
        are broadcast by ``torch.matmul``. The value width ``Dv`` may differ
        from the query/key width. *attn_mask* may be an xformers
        ``BlockDiagonalMask``, a bool tensor (True = may attend) or a float
        tensor added to the logits.
        """
        scale = float(q.shape[-1]) ** -0.5
        logits = torch.matmul(q * scale, k.transpose(-2, -1))  # (..., Nq, Nk)
        n_q, n_k = logits.shape[-2], logits.shape[-1]

        keep = None  # bool mask broadcastable to (..., Nq, Nk); True = attend
        if isinstance(attn_mask, torch.Tensor):
            if attn_mask.dtype == torch.bool:
                keep = attn_mask
            else:
                logits = logits + attn_mask
        elif attn_mask is not None and n_q == n_k:
            # Block-diagonal masks are converted for self-attention only.
            keep = _block_diag_mask_to_dense(attn_mask, n_q, q.device)
        if is_causal:
            causal = torch.ones(n_q, n_k, dtype=torch.bool, device=q.device).tril()
            keep = causal if keep is None else keep & causal
        if keep is not None:
            logits = logits.masked_fill(~keep, float("-inf"))

        weights = torch.softmax(logits, dim=-1)
        # Rows that are fully masked produce NaN after softmax; zero them out.
        weights = weights.nan_to_num(0.0)
        return torch.matmul(weights, v)  # (..., Nq, Dv)

    _gatr_attn.scaled_dot_product_attention = _cpu_sdpa
    _gatr_attn._cpu_patched = True


def _block_diag_mask_to_dense(attn_mask, total_len, device):
    """Convert an ``xformers.ops.fmha.BlockDiagonalMask`` to a dense bool mask.

    Returns a ``(total_len, total_len)`` bool tensor on *device* whose entry
    ``[i, j]`` is True when items *i* and *j* lie in the same block. Returns
    ``None`` when xformers is not importable, when *attn_mask* is not a
    ``BlockDiagonalMask``, or when its block offsets cannot be read.
    """
    try:
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask
    except ImportError:
        return None
    if not isinstance(attn_mask, BlockDiagonalMask):
        return None

    # Per-sequence start offsets: prefer the Python-list attribute and fall
    # back to the tensor form when it is absent.
    try:
        seqstarts = attn_mask.q_seqinfo.seqstart_py
    except AttributeError:
        try:
            seqstarts = attn_mask.q_seqinfo.seqstart.cpu().tolist()
        except Exception:
            return None

    dense = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
    for start, end in zip(seqstarts[:-1], seqstarts[1:]):
        dense[start:end, start:end] = True
    return dense


# -- Label → human-readable name tables ---------------------------------------

# Predicted PID class index → name.
_PID_LABELS = {
    0: "electron",
    1: "charged hadron",
    2: "neutral hadron",
    3: "photon",
    4: "muon",
}

# |PDG ID| → name, used for MC truth and Pandora particles.
_ABS_PDG_NAME = {
    11: "electron",
    13: "muon",
    22: "photon",
    130: "K_L",
    211: "pion±",
    321: "kaon±",
    2112: "neutron",
    2212: "proton",
    310: "K_S",
}

# Graph node ``hit_type`` value → name (values not listed are shown as-is).
_HIT_TYPE_NAMES = {1: "track", 2: "ECAL", 3: "HCAL", 4: "muon"}

# (output column, X_gen column index, decimals) for float-valued MC fields.
_MC_FLOAT_COLUMNS = (
    ("energy", 8, 4),
    ("momentum", 11, 4),
    ("px", 12, 4),
    ("py", 13, 4),
    ("pz", 14, 4),
    ("mass", 10, 4),
    ("theta", 4, 4),
    ("phi", 5, 4),
    ("vx", 15, 4),
    ("vy", 16, 4),
    ("vz", 17, 4),
)

# (output column, X_pandora column index, decimals) for float-valued PFO fields.
_PANDORA_FLOAT_COLUMNS = (
    ("energy", 7, 4),
    ("momentum", 8, 4),
    ("px", 1, 4),
    ("py", 2, 4),
    ("pz", 3, 4),
    ("ref_x", 4, 2),
    ("ref_y", 5, 2),
    ("ref_z", 6, 2),
)


# -- Minimal args namespace for inference -------------------------------------


def _default_args(**overrides):
    """Return a minimal ``argparse.Namespace`` with defaults the model expects.

    Keyword arguments replace (or add to) the defaults below.
    """
    d = dict(
        correction=True,
        freeze_clustering=True,
        predict=True,
        pandora=False,
        use_gt_clusters=False,
        use_average_cc_pos=0.99,
        qmin=1.0,
        data_config="config_files/config_hits_track_v4.yaml",
        network_config="src/models/wrapper/example_mode_gatr_noise.py",
        model_prefix="/tmp/mlpf_eval",
        start_lr=1e-3,
        frac_cluster_loss=0,
        local_rank=0,
        gpus="0",
        batch_size=1,
        num_workers=0,
        prefetch_factor=1,
        num_epochs=1,
        steps_per_epoch=None,
        samples_per_epoch=None,
        steps_per_epoch_val=None,
        samples_per_epoch_val=None,
        train_val_split=0.8,
        data_train=[],
        data_val=[],
        data_test=[],
        data_fraction=1,
        file_fraction=1,
        fetch_by_files=True,
        fetch_step=1,
        log_wandb=False,
        wandb_displayname="",
        wandb_projectname="",
        wandb_entity="",
        name_output="gradio",
        train_batches=100,
    )
    d.update(overrides)
    return argparse.Namespace(**d)


# -- Model loading ------------------------------------------------------------


def load_model(
    clustering_ckpt: str,
    energy_pid_ckpt: Optional[str] = None,
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Load the full MLPF model (clustering + optional energy/PID correction).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (``.ckpt``).
    energy_pid_ckpt : str or None
        Path to the energy-correction / PID checkpoint (``.ckpt``). If *None*,
        only clustering is performed (no energy correction / PID).
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`. The
        ``correction`` key is always overwritten according to whether
        *energy_pid_ckpt* is given.

    Returns
    -------
    model : ExampleWrapper
        The model in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    overrides = dict(args_overrides or {})
    has_correction = energy_pid_ckpt is not None
    overrides["correction"] = has_correction
    args = _default_args(**overrides)
    dev = torch.device(device)

    # NOTE(review): ``dev=0`` is passed to ExampleWrapper regardless of
    # *device*; presumably a device index used internally — confirm.
    if has_correction:
        # NOTE(review): torch.load unpickles the checkpoint; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(energy_pid_ckpt, map_location=dev)
        model = ExampleWrapper(args=args, dev=0)
        model.load_state_dict(ckpt["state_dict"], strict=False)
        # Overwrite the clustering layers with those from the clustering checkpoint.
        clustering_model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )
        model.gatr = clustering_model.gatr
        model.ScaledGooeyBatchNorm2_1 = clustering_model.ScaledGooeyBatchNorm2_1
        model.clustering = clustering_model.clustering
        model.beta = clustering_model.beta
    else:
        model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )

    model = model.to(dev)
    model.eval()
    return model, args


def load_random_model(
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Create a GATr model with randomly initialised weights (no checkpoint).

    This is useful for debugging to verify that checkpoint weights are
    actually being loaded and used by the model.

    Parameters
    ----------
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`
        (``correction`` is always forced to False).

    Returns
    -------
    model : ExampleWrapper
        The model (random weights) in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    overrides = dict(args_overrides or {})
    overrides["correction"] = False
    args = _default_args(**overrides)
    dev = torch.device(device)

    model = ExampleWrapper(args=args, dev=0)
    model = model.to(dev)
    model.eval()
    return model, args


# -- Single-event data loading ------------------------------------------------


def load_event_from_parquet(parquet_path: str, event_index: int = 0):
    """Read a single event from a parquet file.

    Parameters
    ----------
    parquet_path : str
        Path to the parquet file.
    event_index : int
        Index of the event; negative values count from the end.

    Returns
    -------
    dict
        Maps every field of the file (e.g. ``X_hit``, ``X_track``, ``X_gen``,
        ``ygen_hit``, ``ygen_track``) to that event's entry.

    Raises
    ------
    IndexError
        If *event_index* is outside ``[-n_events, n_events)``.
    """
    table = _read_parquet(parquet_path)
    n_events = len(table["X_track"])
    if not -n_events <= event_index < n_events:
        raise IndexError(
            f"event_index {event_index} out of range (file has {n_events} events)"
        )
    return {field: table[field][event_index] for field in table.fields}


# -- Core inference function --------------------------------------------------


@torch.no_grad()
def run_single_event_inference(
    event,
    model,
    args,
    device: str = "cpu",
):
    """Run full MLPF inference on a single event.

    Parameters
    ----------
    event : dict-like
        A single event record (from :func:`load_event_from_parquet`, or an
        awkward record).
    model : ExampleWrapper
        The loaded model (from :func:`load_model`).
    args : argparse.Namespace
        The arguments namespace (from :func:`load_model`).
    device : str
        Device string.

    Returns
    -------
    particles_df : pandas.DataFrame
        One row per predicted cluster. Without energy correction the columns
        are ``cluster_id``, ``energy_sum_hits``, ``p_track``, ``n_hits``,
        ``is_charged``, ``mean_x``, ``mean_y``, ``mean_z``. With energy
        correction they are ``cluster_id``, ``corrected_energy``,
        ``raw_energy``, ``pid_class``, ``pid_label``, ``px``, ``py``, ``pz``,
        ``is_charged``, ``is_fake``. ``cluster_id`` always matches
        ``hit_cluster_df.cluster_id``. If the correction step fails, the basic
        table is returned with an extra ``note`` column and a
        ``RuntimeWarning`` is emitted.
    hit_cluster_df : pandas.DataFrame
        One row per hit with columns: ``hit_index``, ``cluster_id``,
        ``pandora_cluster_id``, ``hit_type_id``, ``hit_type``, ``x``, ``y``,
        ``z``, ``hit_energy``, ``cluster_x``, ``cluster_y``, ``cluster_z``.
        ``cluster_id`` 0 denotes noise. ``pandora_cluster_id`` is -1 when
        pandora data is not available or when the hit has no matching entry
        (e.g. CSV was modified after loading from parquet).
    mc_particles_df : pandas.DataFrame
        One row per MC truth particle with columns: ``particle_idx``, ``pid``,
        ``pdg_name``, ``gen_status``, ``energy``, ``momentum``, ``px``, ``py``,
        ``pz``, ``mass``, ``theta``, ``phi``, ``vx``, ``vy``, ``vz``.
    pandora_particles_df : pandas.DataFrame
        One row per Pandora PFO with columns: ``pfo_idx``, ``pid``,
        ``pdg_name``, ``energy``, ``momentum``, ``px``, ``py``, ``pz``,
        ``ref_x``, ``ref_y``, ``ref_z``. Empty when pandora data is not
        available in the input.

    If graph construction yields an empty graph, the first two DataFrames
    are empty.
    """
    dev = torch.device(device)

    # Ensure eval mode so that BatchNorm layers use running statistics from
    # training instead of computing batch statistics from the current
    # (single-event) input. Without this, inference with batch_size=1
    # produces incorrect normalization.
    model.eval()
    if dev.type == "cpu":
        _patch_gatr_attention_for_cpu()

    # 0. Extract MC truth particles table and pandora particles.
    mc_particles_df = _extract_mc_particles(event)
    pandora_particles_df, pfo_calohit, pfo_track = _extract_pandora_particles(event)

    # 1. Build DGL graph from the event.
    [g, y_data], graph_empty = create_graph(event, for_training=False, args=args)
    if graph_empty:
        return pd.DataFrame(), pd.DataFrame(), mc_particles_df, pandora_particles_df
    g = g.to(dev)
    # Batch metadata expected by the model (single event → batch 0).
    y_data.batch_number = torch.zeros(y_data.E.shape[0], 1)

    # 2. Forward pass through the GATr clustering backbone.
    from gatr.interface import embed_point, embed_scalar, extract_point, extract_scalar
    from xformers.ops.fmha import BlockDiagonalMask

    inputs = g.ndata["pos_hits_xyz"].float().to(dev)
    inputs_scalar = g.ndata["hit_type"].float().view(-1, 1).to(dev)
    inputs_normed = model.ScaledGooeyBatchNorm2_1(inputs)
    embedded_inputs = embed_point(inputs_normed) + embed_scalar(inputs_scalar)
    embedded_inputs = embedded_inputs.unsqueeze(-2)  # add a singleton channel axis
    # A single block spanning every node: all nodes attend to each other.
    mask = BlockDiagonalMask.from_seqlens([g.num_nodes()])
    scalars = torch.cat(
        (g.ndata["e_hits"].float().to(dev), g.ndata["p_hits"].float().to(dev)), dim=1
    )
    embedded_outputs, scalar_outputs = model.gatr(
        embedded_inputs, scalars=scalars, attention_mask=mask
    )
    x_point = extract_point(embedded_outputs[:, 0, :])
    nodewise_outputs = extract_scalar(embedded_outputs)
    x_scalar = torch.cat(
        (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
    )
    x_cluster_coord = model.clustering(x_point)
    beta = model.beta(x_scalar)
    g.ndata["final_cluster"] = x_cluster_coord
    g.ndata["beta"] = beta.view(-1)

    # 3. Density-peak clustering.
    labels = DPC_custom_CLD(x_cluster_coord, g, dev)
    labels, _ = remove_bad_tracks_from_cluster(g, labels)

    # 4. Build hit→cluster table. Pandora IDs are laid out calo hits first,
    # then tracks.
    n_hits = g.num_nodes()
    hit_types_raw = g.ndata["hit_type"].cpu().numpy()
    x_hit = _event_get(event, "X_hit")
    n_calo = 0 if x_hit is None else len(x_hit)
    pandora_cluster_ids = _pandora_cluster_ids(n_hits, n_calo, pfo_calohit, pfo_track)
    positions = g.ndata["pos_hits_xyz"].cpu().numpy()
    coords = x_cluster_coord.cpu().numpy()
    hit_cluster_df = pd.DataFrame({
        "hit_index": np.arange(n_hits),
        "cluster_id": labels.cpu().numpy(),
        "pandora_cluster_id": pandora_cluster_ids,
        "hit_type_id": hit_types_raw,
        "hit_type": [_HIT_TYPE_NAMES.get(int(t), str(int(t))) for t in hit_types_raw],
        "x": positions[:, 0],
        "y": positions[:, 1],
        "z": positions[:, 2],
        "hit_energy": g.ndata["e_hits"].view(-1).cpu().numpy(),
        "cluster_x": coords[:, 0],
        "cluster_y": coords[:, 1],
        "cluster_z": coords[:, 2],
    })

    # 5. Per-cluster summary (basic, before energy correction).
    particles_df = _summarize_clusters(g, labels, dev)

    # 6. If energy correction is available, run it.
    if args.correction and hasattr(model, "energy_correction"):
        try:
            particles_df = _run_energy_correction(
                model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev
            )
        except Exception as exc:
            # Best effort: the basic table is still useful, so attach a note
            # and warn (the note alone is invisible on an empty table).
            message = f"Energy correction failed: {exc}"
            warnings.warn(message, RuntimeWarning, stacklevel=2)
            particles_df["note"] = message

    return particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df


def _pandora_cluster_ids(n_hits, n_calo, pfo_calohit, pfo_track):
    """Return a per-node Pandora PFO index array (-1 where unknown).

    Nodes ``[0, n_calo)`` take their IDs from *pfo_calohit* and the remaining
    nodes from *pfo_track*. Every range is clamped to the available lengths,
    so mismatched inputs (e.g. a modified CSV, or an event with more calo hits
    than graph nodes) never raise.
    """
    ids = np.full(n_hits, -1, dtype=np.int64)
    n_calo = min(max(n_calo, 0), n_hits)
    n_assign = min(len(pfo_calohit), n_calo)
    ids[:n_assign] = pfo_calohit[:n_assign]
    n_assign = min(len(pfo_track), n_hits - n_calo)
    ids[n_calo:n_calo + n_assign] = pfo_track[:n_assign]
    return ids


def _summarize_clusters(g, labels, dev):
    """Build the basic per-cluster table (label 0 = noise is skipped)."""
    from torch_scatter import scatter_add

    labels = labels.to(dev)
    n_hits = g.num_nodes()
    unique_labels = torch.unique(labels)
    cluster_ids = unique_labels[unique_labels > 0].cpu().numpy()

    e_per_cluster = scatter_add(g.ndata["e_hits"].view(-1).to(dev), labels)
    p_per_cluster = scatter_add(g.ndata["p_hits"].view(-1).to(dev), labels)
    n_hits_per_cluster = scatter_add(torch.ones(n_hits, device=dev), labels)
    # A cluster containing at least one track node (hit_type == 1) is charged.
    tracks_per_cluster = scatter_add((g.ndata["hit_type"].to(dev) == 1).float(), labels)
    positions = g.ndata["pos_hits_xyz"]

    rows = []
    for cid in cluster_ids:
        cid = int(cid)
        has_track = tracks_per_cluster[cid].item() >= 1
        pos_mean = positions[labels == cid].mean(dim=0).cpu().numpy()
        rows.append({
            "cluster_id": cid,
            "energy_sum_hits": round(e_per_cluster[cid].item(), 4),
            "p_track": round(p_per_cluster[cid].item(), 4) if has_track else 0.0,
            "n_hits": int(n_hits_per_cluster[cid].item()),
            "is_charged": has_track,
            "mean_x": round(float(pos_mean[0]), 2),
            "mean_y": round(float(pos_mean[1]), 2),
            "mean_z": round(float(pos_mean[2]), 2),
        })
    return pd.DataFrame(rows)


def _extract_mc_particles(event):
    """Build a DataFrame of MC truth particles from the event's ``X_gen``.

    Returns an empty DataFrame when ``X_gen`` is missing, empty, not 2-D, or
    has fewer than 18 columns.
    """
    x_gen = np.asarray(_event_get(event, "X_gen", []))
    if x_gen.ndim != 2 or x_gen.shape[0] == 0 or x_gen.shape[1] < 18:
        return pd.DataFrame()

    rows = []
    for i, rec in enumerate(x_gen):
        pid_raw = int(rec[0])
        row = {
            "particle_idx": i,
            "pid": pid_raw,
            "pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
            "gen_status": int(rec[1]),
        }
        row.update(
            (name, round(float(rec[col]), ndigits))
            for name, col, ndigits in _MC_FLOAT_COLUMNS
        )
        rows.append(row)
    return pd.DataFrame(rows)


def _extract_pandora_particles(event):
    """Build a DataFrame of Pandora PFO particles from the event's ``X_pandora``.

    ``X_pandora`` columns (per PFO):
        0: pid (PDG ID)
        1–3: px, py, pz (momentum components at reference point)
        4–6: ref_x, ref_y, ref_z (reference point)
        7: energy
        8: momentum magnitude

    Returns ``(pandora_particles_df, pfo_calohit, pfo_track)``: the event's
    ``pfo_calohit`` and ``pfo_track`` fields as ``int64`` arrays (empty when
    absent), which map each hit/track to a PFO index (documented as 0-based,
    with -1 = unassigned). The DataFrame is empty when ``X_pandora`` is
    missing, empty, not 2-D, or has fewer than 9 columns.
    """
    x_pandora = np.asarray(_event_get(event, "X_pandora", []))
    pfo_calohit = np.asarray(_event_get(event, "pfo_calohit", []), dtype=np.int64)
    pfo_track = np.asarray(_event_get(event, "pfo_track", []), dtype=np.int64)

    if x_pandora.ndim != 2 or x_pandora.shape[0] == 0 or x_pandora.shape[1] < 9:
        return pd.DataFrame(), pfo_calohit, pfo_track

    rows = []
    for i, rec in enumerate(x_pandora):
        pid_raw = int(rec[0])
        row = {
            "pfo_idx": i,
            "pid": pid_raw,
            "pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
        }
        row.update(
            (name, round(float(rec[col]), ndigits))
            for name, col, ndigits in _PANDORA_FLOAT_COLUMNS
        )
        rows.append(row)
    return pd.DataFrame(rows), pfo_calohit, pfo_track


def _cluster_subgraph(g, mask, dev):
    """Return an edgeless DGL graph holding the node features of hits in *mask*.

    Copies ``h`` and ``chi_squared_tracks`` and, when present in *g*,
    ``pos_pxpypz`` and ``pos_pxpypz_at_vertex``.
    """
    sg = dgl.graph(([], []))
    sg.add_nodes(int(mask.sum()))
    sg = sg.to(dev)
    sg.ndata["h"] = g.ndata["h"][mask]
    for key in ("pos_pxpypz", "pos_pxpypz_at_vertex"):
        if key in g.ndata:
            sg.ndata[key] = g.ndata[key][mask]
    sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
    return sg


def _run_energy_correction(model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev):
    """Run the energy correction & PID branch and return an enriched table.

    Predicted showers matched to exactly one true particle come first in the
    output, followed by unmatched ("fake") showers. ``cluster_id`` is the
    shower's clustering label, so rows join with the hit→cluster table.
    Returns *particles_df* unchanged when there are no showers to correct.
    """
    from src.layers.shower_matching import match_showers
    from torch_scatter import scatter_add, scatter_mean
    from src.utils.post_clustering_features import (
        get_post_clustering_features,
        get_extra_features,
        calculate_eta,
        calculate_phi,
    )

    x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
    particle_ids = torch.unique(g.ndata["particle_number"])
    # ``x`` is passed only so match_showers can read its device.
    shower_p_unique_m, _, col_ind, _, _ = match_showers(
        labels, {"graph": g, "part_true": y_data}, particle_ids, x, 0, 0, None,
    )
    col_ind = torch.Tensor(col_ind).to(dev).long()
    index_matches = (col_ind + 1).to(dev).long()

    e_hits = g.ndata["e_hits"].view(-1).to(dev)
    graphs, reco_energies, cluster_labels = [], [], []

    # Matched showers: labels occurring more than once in index_matches are skipped.
    for sh_label in index_matches:
        if torch.sum(sh_label == index_matches) != 1:
            continue
        hit_mask = labels == sh_label
        graphs.append(_cluster_subgraph(g, hit_mask, dev))
        reco_energies.append(torch.sum(e_hits[hit_mask]).view(-1))
        cluster_labels.append(int(sh_label))
    n_matched = len(graphs)

    # Fakes: predicted showers that are neither matched nor label 0 (noise).
    pred_showers = shower_p_unique_m.clone()
    pred_showers[index_matches] = -1
    pred_showers[0] = -1
    for fake_label in torch.where(pred_showers != -1)[0]:
        hit_mask = labels == fake_label
        graphs.append(_cluster_subgraph(g, hit_mask, dev))
        reco_energies.append(torch.sum(e_hits[hit_mask]).view(-1))
        cluster_labels.append(int(fake_label))

    if not graphs:
        return particles_df

    all_graphs = dgl.batch(graphs)
    sum_e = torch.cat(reco_energies, dim=0)

    # Node → sub-graph index, used to aggregate node features per cluster.
    batch_num_nodes = all_graphs.batch_num_nodes()
    batch_idx = torch.repeat_interleave(
        torch.arange(len(batch_num_nodes), device=batch_num_nodes.device),
        batch_num_nodes,
    ).to(dev)

    # Scale the first three node features by 1/3300
    # (NOTE(review): presumably a detector length scale — confirm).
    all_graphs.ndata["h"][:, 0:3] = all_graphs.ndata["h"][:, 0:3] / 3300
    graphs_sum_features = scatter_add(all_graphs.ndata["h"], batch_idx, dim=0)[batch_idx]
    betas = torch.sigmoid(all_graphs.ndata["h"][:, -1])
    all_graphs.ndata["h"] = torch.cat(
        (all_graphs.ndata["h"], graphs_sum_features), dim=1
    )

    high_level = get_post_clustering_features(all_graphs, sum_e)
    # NOTE(review): the result is unused; the call is kept in case the helper
    # has side effects — confirm and drop if not.
    get_extra_features(all_graphs, betas)

    n_clusters = high_level.shape[0]
    pred_energy = torch.ones(n_clusters, device=dev)
    pred_pos = torch.ones(n_clusters, 3, device=dev)
    pred_pid = torch.ones(n_clusters, device=dev).long()

    node_features_avg = scatter_mean(all_graphs.ndata["h"], batch_idx, dim=0)[:, 0:3]
    eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
    phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
    high_level = torch.cat(
        (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
    )

    # Column 7 of the high-level features is read as the cluster's track count.
    num_tracks = high_level[:, 7]
    charged_idx = torch.where(num_tracks >= 1)[0]
    neutral_idx = torch.where(num_tracks < 1)[0]

    def zero_nans(t):
        return torch.where(torch.isnan(t), torch.zeros_like(t), t)

    charged_out = model.energy_correction.model_charged.charged_prediction(
        all_graphs, charged_idx, zero_nans(high_level[charged_idx]),
    )
    neutral_out, _ = model.energy_correction.model_neutral.neutral_prediction(
        all_graphs, neutral_idx, zero_nans(high_level[neutral_idx]),
    )

    # The prediction heads return PID logits only when PID classes are configured.
    pids_charged = model.energy_correction.pids_charged
    pids_neutral = model.energy_correction.pids_neutral
    if len(pids_charged):
        ch_e, ch_pos, ch_pid_logits, _ = charged_out
    else:
        ch_e, ch_pos, _ = charged_out
        ch_pid_logits = None
    if len(pids_neutral):
        ne_e, ne_pos, ne_pid_logits, _ = neutral_out
    else:
        ne_e, ne_pos, _ = neutral_out
        ne_pid_logits = None

    if len(charged_idx):
        pred_energy[charged_idx] = ch_e
    if len(neutral_idx):
        pred_energy[neutral_idx] = ne_e

    if ch_pid_logits is not None and len(charged_idx):
        ch_labels = np.array(pids_charged)[np.argmax(ch_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[charged_idx] = torch.tensor(ch_labels).long().to(dev)
    if ne_pid_logits is not None and len(neutral_idx):
        ne_labels = np.array(pids_neutral)[np.argmax(ne_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[neutral_idx] = torch.tensor(ne_labels).long().to(dev)

    pred_energy[pred_energy < 0] = 0.0

    # Direction / momentum.
    if len(charged_idx):
        pred_pos[charged_idx] = ch_pos.float().to(dev)
    if len(neutral_idx):
        pred_pos[neutral_idx] = ne_pos.float().to(dev)

    charged_set = set(charged_idx.tolist())
    rows = []
    for k in range(n_clusters):
        pid_cls = int(pred_pid[k].item())
        rows.append({
            "cluster_id": cluster_labels[k],
            "corrected_energy": round(pred_energy[k].item(), 4),
            "raw_energy": round(sum_e[k].item(), 4),
            "pid_class": pid_cls,
            "pid_label": _PID_LABELS.get(pid_cls, str(pid_cls)),
            "px": round(pred_pos[k, 0].item(), 4),
            "py": round(pred_pos[k, 1].item(), 4),
            "pz": round(pred_pos[k, 2].item(), 4),
            "is_charged": k in charged_set,
            "is_fake": k >= n_matched,
        })
    return pd.DataFrame(rows)