# HitPF_demo/src/inference.py
# Synced from GitHub by github-actions[bot] (commits f6dbbfb, cc0720f).
"""
Standalone single-event MLPF inference.
Provides :func:`run_single_event_inference` which takes raw event data
(from a parquet file or as an awkward record) and model checkpoint paths,
runs the full particle-flow pipeline (graph construction → GATr forward
pass → density-peak clustering → energy correction & PID), and returns:
* a ``pandas.DataFrame`` of predicted particles with their properties
* a hit→cluster mapping as a ``pandas.DataFrame``
"""
import argparse
import types
from typing import Optional
import numpy as np
import pandas as pd
import torch
import dgl
import awkward as ak
from src.data.fileio import _read_parquet
from src.dataset.functions_graph import create_graph
from src.dataset.functions_particles import Particles_GT, add_batch_number
from src.layers.clustering import DPC_custom_CLD, remove_bad_tracks_from_cluster
from src.utils.pid_conversion import pid_conversion_dict
# -- CPU-compatible attention patch ------------------------------------------
def _patch_gatr_attention_for_cpu():
    """Swap GATr's xformers attention for a pure-PyTorch fallback.

    ``xformers.ops.fmha.memory_efficient_attention`` ships no CPU kernel, so
    GATr crashes when run on CPU.  This monkey-patches
    ``gatr.primitives.attention.scaled_dot_product_attention`` with a naive
    implementation that works on any device (albeit slower on GPU).
    Idempotent: applying the patch a second time is a no-op.
    """
    import gatr.primitives.attention as _gatr_attn

    if getattr(_gatr_attn, "_cpu_patched", False):
        return

    def _naive_sdpa(q, k, v, attn_mask=None):
        # q/k/v layout: (batch, heads, items, dim)
        batch, heads, n_items, dim = q.shape
        inv_sqrt_d = float(dim) ** -0.5
        q_flat = q.reshape(batch * heads, n_items, dim)
        k_flat = k.reshape(batch * heads, n_items, dim)
        v_flat = v.reshape(batch * heads, n_items, dim)
        # (B*H, N, N) attention scores
        scores = torch.bmm(q_flat * inv_sqrt_d, k_flat.transpose(1, 2))
        if attn_mask is not None:
            dense = _block_diag_mask_to_dense(attn_mask, n_items, q.device)
            if dense is not None:
                scores = scores.masked_fill(~dense.unsqueeze(0), float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        # Fully-masked rows softmax to NaN; replace them with zeros.
        weights = weights.nan_to_num(0.0)
        out_flat = torch.bmm(weights, v_flat)  # (B*H, N, D)
        return out_flat.reshape(batch, heads, n_items, dim)

    _gatr_attn.scaled_dot_product_attention = _naive_sdpa
    _gatr_attn._cpu_patched = True
def _block_diag_mask_to_dense(attn_mask, total_len, device):
"""Convert an ``xformers.ops.fmha.BlockDiagonalMask`` to a dense bool mask."""
try:
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
if not isinstance(attn_mask, BlockDiagonalMask):
return None
except ImportError:
return None
# Extract per-sequence start offsets
try:
seqstarts = attn_mask.q_seqinfo.seqstart_py
except AttributeError:
try:
seqstarts = attn_mask.q_seqinfo.seqstart.cpu().tolist()
except Exception:
return None
mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
for i in range(len(seqstarts) - 1):
s, e = seqstarts[i], seqstarts[i + 1]
mask[s:e, s:e] = True
return mask
# -- PID label → human-readable name ----------------------------------------
# PID class index (output of the model's PID classification head) → readable
# particle-type name; callers fall back to the numeric string via .get().
_PID_LABELS = {
    0: "electron",
    1: "charged hadron",
    2: "neutral hadron",
    3: "photon",
    4: "muon",
}
# |PDG ID| → readable species name for the particles that appear in these
# events; unknown IDs fall back to the numeric string via .get().
_ABS_PDG_NAME = {
    11: "electron",
    13: "muon",
    22: "photon",
    130: "K_L",
    211: "pion±",
    321: "kaon±",
    2112: "neutron",
    2212: "proton",
    310: "K_S",
}
# -- Minimal args namespace for inference ------------------------------------
def _default_args(**overrides):
"""Return a minimal ``argparse.Namespace`` with defaults the model expects."""
d = dict(
correction=True,
freeze_clustering=True,
predict=True,
pandora=False,
use_gt_clusters=False,
use_average_cc_pos=0.99,
qmin=1.0,
data_config="config_files/config_hits_track_v4.yaml",
network_config="src/models/wrapper/example_mode_gatr_noise.py",
model_prefix="/tmp/mlpf_eval",
start_lr=1e-3,
frac_cluster_loss=0,
local_rank=0,
gpus="0",
batch_size=1,
num_workers=0,
prefetch_factor=1,
num_epochs=1,
steps_per_epoch=None,
samples_per_epoch=None,
steps_per_epoch_val=None,
samples_per_epoch_val=None,
train_val_split=0.8,
data_train=[],
data_val=[],
data_test=[],
data_fraction=1,
file_fraction=1,
fetch_by_files=True,
fetch_step=1,
log_wandb=False,
wandb_displayname="",
wandb_projectname="",
wandb_entity="",
name_output="gradio",
train_batches=100,
)
d.update(overrides)
return argparse.Namespace(**d)
# -- Model loading -----------------------------------------------------------
def load_model(
    clustering_ckpt: str,
    energy_pid_ckpt: Optional[str] = None,
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Load the full MLPF model (clustering + optional energy/PID correction).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (``.ckpt``).
    energy_pid_ckpt : str or None
        Path to the energy-correction / PID checkpoint (``.ckpt``).
        If *None*, only clustering is performed (no energy correction / PID).
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper
    overrides = dict(args_overrides or {})
    # The correction branch is enabled iff a correction checkpoint is given;
    # this deliberately overrides any user-supplied "correction" flag.
    has_correction = energy_pid_ckpt is not None
    overrides["correction"] = has_correction
    args = _default_args(**overrides)
    dev = torch.device(device)
    if has_correction:
        # NOTE(review): torch>=2.6 defaults to weights_only=True in
        # torch.load, which can reject Lightning checkpoints — may need
        # weights_only=False here; confirm against the installed torch.
        ckpt = torch.load(energy_pid_ckpt, map_location=dev)
        state_dict = ckpt["state_dict"]
        # strict=False: the correction checkpoint need not cover every module.
        model = ExampleWrapper(args=args, dev=0)
        model.load_state_dict(state_dict, strict=False)
        # Overwrite clustering layers from clustering checkpoint, so the
        # clustering backbone weights come from clustering_ckpt while the
        # correction/PID heads keep the weights loaded above.
        model2 = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )
        model.gatr = model2.gatr
        model.ScaledGooeyBatchNorm2_1 = model2.ScaledGooeyBatchNorm2_1
        model.clustering = model2.clustering
        model.beta = model2.beta
    else:
        model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )
    model = model.to(dev)
    model.eval()
    return model, args
def load_random_model(
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Build an un-trained GATr model (random weights, no checkpoint).

    Handy for debugging: comparing against this model verifies that
    checkpoint weights are actually being loaded and used.

    Parameters
    ----------
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model (random weights) in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    merged = {**(args_overrides or {}), "correction": False}
    args = _default_args(**merged)
    model = ExampleWrapper(args=args, dev=0).to(torch.device(device))
    model.eval()
    return model, args
# -- Single-event data loading -----------------------------------------------
def load_event_from_parquet(parquet_path: str, event_index: int = 0):
    """Read one event from a parquet file.

    Returns an awkward record with fields ``X_hit``, ``X_track``, ``X_gen``,
    ``ygen_hit``, ``ygen_track``, etc.

    Raises
    ------
    IndexError
        If *event_index* is not smaller than the number of events on file.
    """
    table = _read_parquet(parquet_path)
    n_events = len(table["X_track"])
    if event_index >= n_events:
        raise IndexError(
            f"event_index {event_index} out of range (file has {n_events} events)"
        )
    return {name: table[name][event_index] for name in table.fields}
# -- Core inference function --------------------------------------------------
@torch.no_grad()
def run_single_event_inference(
    event,
    model,
    args,
    device: str = "cpu",
):
    """Run full MLPF inference on a single event.

    Pipeline: graph construction → GATr forward pass → density-peak
    clustering → (optional) energy correction & PID.

    Parameters
    ----------
    event : dict-like
        A single event record (from :func:`load_event_from_parquet`).
    model : ExampleWrapper
        The loaded model (from :func:`load_model`).
    args : argparse.Namespace
        The arguments namespace (from :func:`load_model`).
    device : str
        Device string.

    Returns
    -------
    particles_df : pandas.DataFrame
        One row per predicted particle with columns:
        ``cluster_id``, ``energy``, ``pid_class``, ``pid_label``,
        ``px``, ``py``, ``pz``, ``is_charged``.
    hit_cluster_df : pandas.DataFrame
        One row per hit with columns:
        ``hit_index``, ``cluster_id``, ``pandora_cluster_id``,
        ``hit_type_id``, ``hit_type``, ``x``, ``y``, ``z``,
        ``hit_energy``, ``cluster_x``, ``cluster_y``, ``cluster_z``.
        ``pandora_cluster_id`` is -1 when pandora data is not available
        or when the hit has no matching entry (e.g. CSV was modified after
        loading from parquet).
    mc_particles_df : pandas.DataFrame
        One row per MC truth particle with columns:
        ``pid``, ``energy``, ``momentum``, ``px``, ``py``, ``pz``,
        ``mass``, ``theta``, ``phi``, ``vx``, ``vy``, ``vz``,
        ``gen_status``, ``pdg_name``.
    pandora_particles_df : pandas.DataFrame
        One row per Pandora PFO with columns:
        ``pfo_idx``, ``pid``, ``pdg_name``, ``energy``, ``momentum``,
        ``px``, ``py``, ``pz``, ``ref_x``, ``ref_y``, ``ref_z``.
        Empty when pandora data is not available in the input.
    """
    dev = torch.device(device)
    # Ensure eval mode so that BatchNorm layers use running statistics from
    # training instead of computing batch statistics from the current
    # (single-event) input. Without this, inference with batch_size=1
    # produces incorrect normalization.
    model.eval()
    if dev.type == "cpu":
        _patch_gatr_attention_for_cpu()
    # 0. Extract MC truth particles table and pandora particles
    mc_particles_df = _extract_mc_particles(event)
    pandora_particles_df, pfo_calohit, pfo_track = _extract_pandora_particles(event)
    # 1. Build DGL graph from the event
    [g, y_data], graph_empty = create_graph(event, for_training=False, args=args)
    if graph_empty:
        # No usable hits: return empty prediction tables but keep the truth
        # and pandora tables, which do not depend on the graph.
        return pd.DataFrame(), pd.DataFrame(), mc_particles_df, pandora_particles_df
    g = g.to(dev)
    # Prepare batch metadata expected by the model: single event, so every
    # truth particle is assigned batch number 0.
    y_data.batch_number = torch.zeros(y_data.E.shape[0], 1)
    # 2. Forward pass through the GATr clustering backbone.
    # NOTE(review): this re-implements the clustering portion of the model's
    # forward pass inline (embedding → gatr → clustering/beta heads); keep it
    # in sync with ExampleWrapper's own forward — confirm.
    inputs = g.ndata["pos_hits_xyz"].float().to(dev)
    inputs_scalar = g.ndata["hit_type"].float().view(-1, 1).to(dev)
    from gatr.interface import embed_point, embed_scalar
    from xformers.ops.fmha import BlockDiagonalMask
    inputs_normed = model.ScaledGooeyBatchNorm2_1(inputs)
    embedded_inputs = embed_point(inputs_normed) + embed_scalar(inputs_scalar)
    embedded_inputs = embedded_inputs.unsqueeze(-2)
    # A single block spanning all nodes: full self-attention within the event.
    mask = BlockDiagonalMask.from_seqlens([g.num_nodes()])
    scalars = torch.cat(
        (g.ndata["e_hits"].float().to(dev), g.ndata["p_hits"].float().to(dev)), dim=1
    )
    from gatr.interface import extract_point, extract_scalar
    embedded_outputs, scalar_outputs = model.gatr(
        embedded_inputs, scalars=scalars, attention_mask=mask
    )
    points = extract_point(embedded_outputs[:, 0, :])
    nodewise_outputs = extract_scalar(embedded_outputs)
    x_point = points
    x_scalar = torch.cat(
        (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
    )
    # Clustering head gives per-hit coordinates in cluster space; beta head
    # gives the per-hit condensation strength.
    x_cluster_coord = model.clustering(x_point)
    beta = model.beta(x_scalar)
    g.ndata["final_cluster"] = x_cluster_coord
    g.ndata["beta"] = beta.view(-1)
    # 3. Density-peak clustering
    labels = DPC_custom_CLD(x_cluster_coord, g, dev)
    labels, _ = remove_bad_tracks_from_cluster(g, labels)
    # 4. Build hit→cluster table
    n_hits = g.num_nodes()
    hit_types_raw = g.ndata["hit_type"].cpu().numpy()
    hit_type_names = {1: "track", 2: "ECAL", 3: "HCAL", 4: "muon"}
    # Build pandora cluster ID per node (hits first, then tracks).
    # Assumes graph node order is calo hits [0, n_calo) followed by tracks —
    # TODO confirm against create_graph.
    # Use min of array lengths for graceful handling when CSV was modified
    n_calo = len(np.asarray(event.get("X_hit", [])))
    pandora_cluster_ids = np.full(n_hits, -1, dtype=np.int64)
    if len(pfo_calohit) > 0:
        n_assign = min(len(pfo_calohit), n_calo)
        pandora_cluster_ids[:n_assign] = pfo_calohit[:n_assign]
    n_tracks = n_hits - n_calo
    if n_tracks > 0 and len(pfo_track) > 0:
        n_assign = min(len(pfo_track), n_tracks)
        pandora_cluster_ids[n_calo:n_calo + n_assign] = pfo_track[:n_assign]
    hit_cluster_df = pd.DataFrame({
        "hit_index": np.arange(n_hits),
        "cluster_id": labels.cpu().numpy(),
        "pandora_cluster_id": pandora_cluster_ids,
        "hit_type_id": hit_types_raw,
        "hit_type": [hit_type_names.get(int(t), str(int(t))) for t in hit_types_raw],
        "x": g.ndata["pos_hits_xyz"][:, 0].cpu().numpy(),
        "y": g.ndata["pos_hits_xyz"][:, 1].cpu().numpy(),
        "z": g.ndata["pos_hits_xyz"][:, 2].cpu().numpy(),
        "hit_energy": g.ndata["e_hits"].view(-1).cpu().numpy(),
        "cluster_x": x_cluster_coord[:, 0].cpu().numpy(),
        "cluster_y": x_cluster_coord[:, 1].cpu().numpy(),
        "cluster_z": x_cluster_coord[:, 2].cpu().numpy(),
    })
    # 5. Per-cluster summary (basic, before energy correction)
    unique_labels = torch.unique(labels)
    # cluster 0 = noise
    cluster_ids = unique_labels[unique_labels > 0].cpu().numpy()
    from torch_scatter import scatter_add
    # Per-cluster sums indexed by cluster label (index 0 holds the noise sum).
    e_per_cluster = scatter_add(
        g.ndata["e_hits"].view(-1).to(dev), labels.to(dev)
    )
    p_per_cluster = scatter_add(
        g.ndata["p_hits"].view(-1).to(dev), labels.to(dev)
    )
    n_hits_per_cluster = scatter_add(
        torch.ones(n_hits, device=dev), labels.to(dev)
    )
    # Check if any cluster has a track (→ charged)
    is_track_per_cluster = scatter_add(
        (g.ndata["hit_type"].to(dev) == 1).float(), labels.to(dev)
    )
    rows = []
    for cid in cluster_ids:
        mask_c = labels == cid
        e_sum = e_per_cluster[cid].item()
        p_sum = p_per_cluster[cid].item()
        n_h = int(n_hits_per_cluster[cid].item())
        has_track = is_track_per_cluster[cid].item() >= 1
        # Mean position
        pos_mean = g.ndata["pos_hits_xyz"][mask_c].mean(dim=0).cpu().numpy()
        rows.append({
            "cluster_id": int(cid),
            "energy_sum_hits": round(e_sum, 4),
            "p_track": round(p_sum, 4) if has_track else 0.0,
            "n_hits": n_h,
            "is_charged": has_track,
            "mean_x": round(float(pos_mean[0]), 2),
            "mean_y": round(float(pos_mean[1]), 2),
            "mean_z": round(float(pos_mean[2]), 2),
        })
    particles_df = pd.DataFrame(rows)
    # 6. If energy correction is available, run it
    if args.correction and hasattr(model, "energy_correction"):
        try:
            particles_df = _run_energy_correction(
                model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev
            )
        except Exception as e:
            # Attach a note but don't crash – the basic table is still useful
            particles_df["note"] = f"Energy correction failed: {e}"
    return particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df
def _extract_mc_particles(event):
"""Build a DataFrame of MC truth particles from the event's ``X_gen``."""
x_gen = np.asarray(event.get("X_gen", []))
if x_gen.ndim != 2 or x_gen.shape[0] == 0 or x_gen.shape[1] < 18:
return pd.DataFrame()
rows = []
for i in range(x_gen.shape[0]):
pid_raw = int(x_gen[i, 0])
rows.append({
"particle_idx": i,
"pid": pid_raw,
"pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
"gen_status": int(x_gen[i, 1]),
"energy": round(float(x_gen[i, 8]), 4),
"momentum": round(float(x_gen[i, 11]), 4),
"px": round(float(x_gen[i, 12]), 4),
"py": round(float(x_gen[i, 13]), 4),
"pz": round(float(x_gen[i, 14]), 4),
"mass": round(float(x_gen[i, 10]), 4),
"theta": round(float(x_gen[i, 4]), 4),
"phi": round(float(x_gen[i, 5]), 4),
"vx": round(float(x_gen[i, 15]), 4),
"vy": round(float(x_gen[i, 16]), 4),
"vz": round(float(x_gen[i, 17]), 4),
})
return pd.DataFrame(rows)
def _extract_pandora_particles(event):
"""Build a DataFrame of Pandora PFO particles from the event's ``X_pandora``.
``X_pandora`` columns (per PFO):
0: pid (PDG ID)
1–3: px, py, pz (momentum components at reference point)
4–6: ref_x, ref_y, ref_z (reference point)
7: energy
8: momentum magnitude
Returns (pandora_particles_df, pfo_hit_links, pfo_track_links) where
*pfo_hit_links* and *pfo_track_links* are integer arrays mapping each
hit/track to a PFO index (0-based, -1 = unassigned).
"""
x_pandora = np.asarray(event.get("X_pandora", []))
pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64)
if x_pandora.ndim != 2 or x_pandora.shape[0] == 0 or x_pandora.shape[1] < 9:
return pd.DataFrame(), pfo_calohit, pfo_track
rows = []
for i in range(x_pandora.shape[0]):
pid_raw = int(x_pandora[i, 0])
rows.append({
"pfo_idx": i,
"pid": pid_raw,
"pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
"energy": round(float(x_pandora[i, 7]), 4),
"momentum": round(float(x_pandora[i, 8]), 4),
"px": round(float(x_pandora[i, 1]), 4),
"py": round(float(x_pandora[i, 2]), 4),
"pz": round(float(x_pandora[i, 3]), 4),
"ref_x": round(float(x_pandora[i, 4]), 2),
"ref_y": round(float(x_pandora[i, 5]), 2),
"ref_z": round(float(x_pandora[i, 6]), 2),
})
return pd.DataFrame(rows), pfo_calohit, pfo_track
def _run_energy_correction(model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev):
    """Run the energy correction & PID branch and enrich *particles_df*.

    Steps: match predicted clusters to truth showers, build one sub-graph per
    predicted cluster (matched clusters first, then fakes), compute per-cluster
    high-level features, then run the charged/neutral energy-correction and
    PID heads.  Returns a new DataFrame with one row per cluster
    (corrected/raw energy, PID class and label, momentum direction, charged
    and fake flags); *particles_df* is returned unchanged when there are no
    clusters to correct.
    """
    from src.layers.shower_matching import match_showers, obtain_intersection_matrix, obtain_union_matrix
    from torch_scatter import scatter_add, scatter_mean
    from src.utils.post_clustering_features import (
        get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
    )
    x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
    # Re-create per-cluster sub-graphs expected by the correction pipeline
    particle_ids = torch.unique(g.ndata["particle_number"])
    shower_p_unique = torch.unique(labels)
    model_output_dummy = x # used only for device by match_showers
    shower_p_unique_m, row_ind, col_ind, i_m_w, _ = match_showers(
        labels, {"graph": g, "part_true": y_data},
        particle_ids, model_output_dummy, 0, 0, None,
    )
    row_ind = torch.Tensor(row_ind).to(dev).long()
    col_ind = torch.Tensor(col_ind).to(dev).long()
    # When truth particle id 0 (noise) is present, truth row indices are
    # shifted by one relative to the matching output.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind
    # +1: cluster labels are 1-based (0 = noise).
    index_matches = (col_ind + 1).to(dev).long()
    # Build per-cluster sub-graphs (matched + fakes)
    graphs_matched = []
    true_energies = []
    reco_energies = []
    pids_matched = []
    coords_matched = []
    e_true_daughters = []
    for j, sh_label in enumerate(index_matches):
        # Skip labels matched more than once (ambiguous assignments).
        if torch.sum(sh_label == index_matches) == 1:
            mask = labels == sh_label
            # Edgeless sub-graph carrying only the cluster's node features.
            sg = dgl.graph(([], []))
            sg.add_nodes(int(mask.sum()))
            sg = sg.to(dev)
            sg.ndata["h"] = g.ndata["h"][mask]
            if "pos_pxpypz" in g.ndata:
                sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
            if "pos_pxpypz_at_vertex" in g.ndata:
                sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
            sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
            energy_t = y_data.E.to(dev)
            true_e = energy_t[row_ind_[j]]
            pids_matched.append(y_data.pid[row_ind_[j]].item())
            coords_matched.append(y_data.coord[row_ind_[j]].detach().cpu().numpy())
            e_true_daughters.append(y_data.m[row_ind_[j]].to(dev))
            reco_e = torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask])
            graphs_matched.append(sg)
            true_energies.append(true_e.view(-1))
            reco_energies.append(reco_e.view(-1))
    # Add fakes
    # Fakes = predicted clusters (excluding noise bucket 0) with no truth match.
    pred_showers = shower_p_unique_m.clone()
    pred_showers[index_matches] = -1
    pred_showers[0] = -1
    fakes_mask = pred_showers != -1
    fakes_idx = torch.where(fakes_mask)[0]
    graphs_fakes = []
    reco_fakes = []
    for fi in fakes_idx:
        mask = labels == fi
        sg = dgl.graph(([], []))
        sg.add_nodes(int(mask.sum()))
        sg = sg.to(dev)
        sg.ndata["h"] = g.ndata["h"][mask]
        if "pos_pxpypz" in g.ndata:
            sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
        if "pos_pxpypz_at_vertex" in g.ndata:
            sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
        sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
        graphs_fakes.append(sg)
        reco_fakes.append(torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask]).view(-1))
    if not graphs_matched and not graphs_fakes:
        return particles_df
    all_graphs = dgl.batch(graphs_matched + graphs_fakes)
    sum_e = torch.cat(reco_energies + reco_fakes, dim=0)
    # Compute high-level features
    # batch_idx maps each node to its cluster index inside the batch.
    batch_num_nodes = all_graphs.batch_num_nodes()
    batch_idx = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(dev)
    # Normalize positions by 3300 — presumably the detector scale in mm;
    # TODO confirm against the training-time feature pipeline.
    all_graphs.ndata["h"][:, 0:3] = all_graphs.ndata["h"][:, 0:3] / 3300
    graphs_sum_features = scatter_add(all_graphs.ndata["h"], batch_idx, dim=0)
    graphs_sum_features = graphs_sum_features[batch_idx]
    betas = torch.sigmoid(all_graphs.ndata["h"][:, -1])
    # Append per-cluster feature sums to each node's features.
    all_graphs.ndata["h"] = torch.cat(
        (all_graphs.ndata["h"], graphs_sum_features), dim=1
    )
    high_level = get_post_clustering_features(all_graphs, sum_e)
    extra_features = get_extra_features(all_graphs, betas)
    n_clusters = high_level.shape[0]
    pred_energy = torch.ones(n_clusters, device=dev)
    pred_pos = torch.ones(n_clusters, 3, device=dev)
    pred_pid = torch.ones(n_clusters, device=dev).long()
    node_features_avg = scatter_mean(all_graphs.ndata["h"], batch_idx, dim=0)[:, 0:3]
    eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
    phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
    high_level = torch.cat(
        (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
    )
    # Column 7 counts tracks in the cluster — clusters with a track are charged.
    num_tracks = high_level[:, 7]
    charged_idx = torch.where(num_tracks >= 1)[0]
    neutral_idx = torch.where(num_tracks < 1)[0]
    def zero_nans(t):
        # Replace NaNs (x != x only for NaN) with 0 before feeding the heads.
        out = t.clone()
        out[out != out] = 0
        return out
    feats_charged = zero_nans(high_level[charged_idx])
    feats_neutral = zero_nans(high_level[neutral_idx])
    # Run charged prediction
    charged_energies = model.energy_correction.model_charged.charged_prediction(
        all_graphs, charged_idx, feats_charged,
    )
    # Run neutral prediction
    neutral_energies, neutral_pxyz_avg = model.energy_correction.model_neutral.neutral_prediction(
        all_graphs, neutral_idx, feats_neutral,
    )
    pids_charged = model.energy_correction.pids_charged
    pids_neutral = model.energy_correction.pids_neutral
    # The heads return PID logits only when PID classes are configured.
    if len(pids_charged):
        ch_e, ch_pos, ch_pid_logits, ch_ref = charged_energies
    else:
        ch_e, ch_pos, _ = charged_energies
        ch_pid_logits = None
    if len(pids_neutral):
        ne_e, ne_pos, ne_pid_logits, ne_ref = neutral_energies
    else:
        ne_e, ne_pos, _ = neutral_energies
        ne_pid_logits = None
    pred_energy[charged_idx.flatten()] = ch_e if len(charged_idx) else pred_energy[charged_idx.flatten()]
    pred_energy[neutral_idx.flatten()] = ne_e if len(neutral_idx) else pred_energy[neutral_idx.flatten()]
    if ch_pid_logits is not None and len(charged_idx):
        ch_labels = np.array(pids_charged)[np.argmax(ch_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[charged_idx.flatten()] = torch.tensor(ch_labels).long().to(dev)
    if ne_pid_logits is not None and len(neutral_idx):
        ne_labels = np.array(pids_neutral)[np.argmax(ne_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[neutral_idx.flatten()] = torch.tensor(ne_labels).long().to(dev)
    # Clamp unphysical negative energies to zero.
    pred_energy[pred_energy < 0] = 0.0
    # Direction
    if len(charged_idx):
        pred_pos[charged_idx.flatten()] = ch_pos.float().to(dev)
    if len(neutral_idx):
        pred_pos[neutral_idx.flatten()] = ne_pos.float().to(dev)
    # Build enriched output DataFrame
    # Matched clusters occupy indices [0, n_matched); the rest are fakes.
    n_matched = len(graphs_matched)
    rows = []
    for k in range(n_clusters):
        is_fake = k >= n_matched
        pid_cls = int(pred_pid[k].item())
        rows.append({
            "cluster_id": k + 1,
            "corrected_energy": round(pred_energy[k].item(), 4),
            "raw_energy": round(sum_e[k].item(), 4),
            "pid_class": pid_cls,
            "pid_label": _PID_LABELS.get(pid_cls, str(pid_cls)),
            "px": round(pred_pos[k, 0].item(), 4),
            "py": round(pred_pos[k, 1].item(), 4),
            "pz": round(pred_pos[k, 2].item(), 4),
            "is_charged": bool(k in charged_idx),
            "is_fake": is_fake,
        })
    return pd.DataFrame(rows)