# HitPF_demo / app.py
# Synced from GitHub (commit f6dbbfb) by github-actions[bot] — cc0720f
#!/usr/bin/env python
"""
Gradio UI for single-event MLPF inference.
Launch with:
python app.py [--device cpu]
The UI lets you:
1. Load an event from a parquet file (pick file + event index), **or**
paste hit / track / particle data in CSV format.
2. (Optionally) load pre-trained model checkpoints.
3. Run inference → view predicted particles and the hit→cluster mapping.
"""
import argparse
import os
import shutil
import traceback
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Auto-download demo files from Hugging Face Hub if they are not present
# ---------------------------------------------------------------------------
# Hugging Face Hub dataset repo that hosts the demo checkpoints and sample data.
_HF_REPO_ID = "gregorkrzmanc/hitpf_demo_files"
# Files fetched from the Hub at startup when missing locally
# (see _ensure_demo_files below).
_DEMO_FILES = [
    "model_clustering.ckpt",
    "model_e_pid.ckpt",
    "test_data.parquet",
]
def _ensure_demo_files(dest_dir: str = ".") -> None:
    """Download demo files from Hugging Face Hub if they don't already exist.

    Best-effort: a failed download is logged and skipped so the app can still
    start (the user may supply their own checkpoints/data paths in the UI).
    """
    for fname in _DEMO_FILES:
        target = os.path.join(dest_dir, fname)
        if os.path.isfile(target):
            continue  # already present locally — nothing to do
        try:
            print(f"Downloading {fname} from HF Hub ({_HF_REPO_ID}) …")
            cached_path = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename=fname,
                repo_type="dataset",
            )
            # Copy out of the HF cache so the file lives at a stable path.
            shutil.copy(cached_path, target)
            print(f" → saved to {target}")
        except Exception as exc:
            print(f" ⚠️ Could not download {fname}: {exc}")
# Fetch any missing demo assets at import time so the UI works out of the box.
_ensure_demo_files()
# ---------------------------------------------------------------------------
# Global state – filled lazily
# ---------------------------------------------------------------------------
_MODEL = None    # loaded model object (populated by load_model_ui)
_ARGS = None     # model args/config returned alongside the model
_DEVICE = "cpu"  # device string (e.g. "cpu" / "cuda:0") used for inference


def _set_device(device: str):
    """Set the module-level default inference device (e.g. "cpu", "cuda:0")."""
    global _DEVICE
    _DEVICE = device
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model_ui(clustering_ckpt: str, energy_pid_ckpt: str, device: str):
    """Load model from checkpoint paths (called by the UI button).

    Returns a Markdown status string: a warning for a bad clustering path,
    a success message on load, or the formatted traceback on failure.
    """
    global _MODEL, _ARGS, _DEVICE
    _DEVICE = device or "cpu"
    if not (clustering_ckpt and os.path.isfile(clustering_ckpt)):
        return "⚠️ Please provide a valid path to the clustering checkpoint."
    # The energy/PID checkpoint is optional — ignore it when missing.
    energy_pid = None
    if energy_pid_ckpt and os.path.isfile(energy_pid_ckpt):
        energy_pid = energy_pid_ckpt
    try:
        from src.inference import load_model
        _MODEL, _ARGS = load_model(
            clustering_ckpt=clustering_ckpt,
            energy_pid_ckpt=energy_pid,
            device=_DEVICE,
        )
    except Exception:
        return f"❌ Failed to load model:\n```\n{traceback.format_exc()}\n```"
    suffix = (
        " (clustering + energy/PID correction)"
        if energy_pid
        else " (clustering only — no energy/PID correction)"
    )
    return f"✅ Model loaded on **{_DEVICE}**" + suffix
# ---------------------------------------------------------------------------
# Event loading helpers
# ---------------------------------------------------------------------------
def _count_events_in_parquet(parquet_path: str) -> str:
"""Return a short info string about the parquet file."""
if not parquet_path or not os.path.isfile(parquet_path):
return "No file selected"
try:
from src.inference import load_event_from_parquet
from src.data.fileio import _read_parquet
table = _read_parquet(parquet_path)
n = len(table["X_track"])
return f"File has **{n}** events (indices 0–{n-1})"
except Exception as e:
return f"Error reading file: {e}"
def _load_event_into_csv(parquet_path: str, event_index: int):
"""Load an event from a parquet file and return CSV strings for the text fields."""
if not parquet_path or not os.path.isfile(parquet_path):
return "", "", "", "", "", "⚠️ Please provide a valid parquet file path."
try:
from src.inference import load_event_from_parquet
event = load_event_from_parquet(parquet_path, int(event_index))
hits_arr = np.asarray(event.get("X_hit", []))
tracks_arr = np.asarray(event.get("X_track", []))
particles_arr = np.asarray(event.get("X_gen", []))
pandora_arr = np.asarray(event.get("X_pandora", []))
def _arr_to_csv(arr):
if arr.ndim != 2:
return ""
return "\n".join(",".join(str(v) for v in row) for row in arr)
def _1d_to_csv(arr):
if len(arr) == 0:
return ""
return ",".join(str(int(v)) for v in arr)
pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64)
calohit_csv = _1d_to_csv(pfo_calohit)
track_csv = _1d_to_csv(pfo_track)
if calohit_csv and track_csv:
pfo_links_csv = calohit_csv + "\n" + track_csv
elif calohit_csv:
pfo_links_csv = calohit_csv
elif track_csv:
pfo_links_csv = "\n" + track_csv
else:
pfo_links_csv = ""
return (
_arr_to_csv(hits_arr),
_arr_to_csv(tracks_arr),
_arr_to_csv(particles_arr),
_arr_to_csv(pandora_arr),
pfo_links_csv,
f"✅ Loaded event **{int(event_index)}**: "
f"{hits_arr.shape[0] if hits_arr.ndim == 2 else 0} hits, "
f"{tracks_arr.shape[0] if tracks_arr.ndim == 2 else 0} tracks, "
f"{particles_arr.shape[0] if particles_arr.ndim == 2 else 0} MC particles, "
f"{pandora_arr.shape[0] if pandora_arr.ndim == 2 else 0} Pandora PFOs",
)
except Exception as e:
return "", "", "", "", "", f"❌ Error loading event: {e}"
def _build_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by cluster ID."""
    if hit_cluster_df.empty:
        empty = go.Figure()
        empty.update_layout(title="No hit data available", height=600)
        return empty
    data = hit_cluster_df.copy()
    # Coerce coordinates/energy to numeric and drop rows with NaN/Inf values.
    numeric_cols = ["x", "y", "z", "hit_energy"]
    for col in numeric_cols:
        data[col] = pd.to_numeric(data[col], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=numeric_cols)
    if data.empty:
        empty = go.Figure()
        empty.update_layout(title="No valid hit data (all NaN/Inf)", height=600)
        return empty
    # Map hit energies onto marker sizes in [3, 15] (midpoint when all equal).
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12
    # Per-hit hover text (avoids mixed-type customdata serialization issues).
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + data["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + data["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + data["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + data["z"].map(lambda v: f"{v:.2f}")
    )
    ids = data["cluster_id"].values
    fig = go.Figure()
    # One trace per cluster so the legend can toggle clusters individually.
    for cid in sorted({int(c) for c in ids}):
        selector = ids == cid
        points = data[selector]
        fig.add_trace(go.Scatter3d(
            x=points["x"].tolist(),
            y=points["y"].tolist(),
            z=points["z"].tolist(),
            mode="markers",
            name="noise" if cid == 0 else f"cluster {cid}",
            marker=dict(size=sizes_all[selector].tolist(), opacity=0.8),
            hovertext=points["_hover"].tolist(),
            hoverinfo="text",
        ))
    fig.update_layout(
        title="Hit → Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
def _build_pandora_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by Pandora cluster ID."""
    if hit_cluster_df.empty or "pandora_cluster_id" not in hit_cluster_df.columns:
        empty = go.Figure()
        empty.update_layout(title="No Pandora cluster data available", height=600)
        return empty
    data = hit_cluster_df.copy()
    # Coerce coordinates/energy to numeric and drop rows with NaN/Inf values.
    # Hits without a Pandora assignment (id == -1) are kept and drawn as
    # their own "unassigned" trace below.
    numeric_cols = ["x", "y", "z", "hit_energy"]
    for col in numeric_cols:
        data[col] = pd.to_numeric(data[col], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=numeric_cols)
    if data.empty:
        empty = go.Figure()
        empty.update_layout(title="No valid hit data for Pandora plot (all NaN/Inf)", height=600)
        return empty
    # Map hit energies onto marker sizes in [3, 15] (midpoint when all equal).
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12
    # Per-hit hover text (avoids mixed-type customdata serialization issues).
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Pandora cluster: " + data["pandora_cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + data["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + data["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + data["z"].map(lambda v: f"{v:.2f}")
    )
    ids = data["pandora_cluster_id"].values
    fig = go.Figure()
    # One trace per Pandora PFO so the legend can toggle them individually.
    for cid in sorted({int(c) for c in ids}):
        selector = ids == cid
        points = data[selector]
        fig.add_trace(go.Scatter3d(
            x=points["x"].tolist(),
            y=points["y"].tolist(),
            z=points["z"].tolist(),
            mode="markers",
            name="unassigned" if cid == -1 else f"PFO {cid}",
            marker=dict(size=sizes_all[selector].tolist(), opacity=0.8),
            hovertext=points["_hover"].tolist(),
            hoverinfo="text",
        ))
    fig.update_layout(
        title="Hit → Pandora Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Pandora PFO",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
def _build_clustering_space_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits in the learned clustering space."""
    if hit_cluster_df.empty or "cluster_x" not in hit_cluster_df.columns:
        empty = go.Figure()
        empty.update_layout(title="No clustering-space data available", height=600)
        return empty
    data = hit_cluster_df.copy()
    # Coerce learned coordinates/energy to numeric and drop NaN/Inf rows.
    numeric_cols = ["cluster_x", "cluster_y", "cluster_z", "hit_energy"]
    for col in numeric_cols:
        data[col] = pd.to_numeric(data[col], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=numeric_cols)
    if data.empty:
        empty = go.Figure()
        empty.update_layout(title="No valid clustering-space data (all NaN/Inf)", height=600)
        return empty
    # Map hit energies onto marker sizes in [3, 15] (midpoint when all equal).
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12
    # Per-hit hover text (avoids mixed-type customdata serialization issues).
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + data["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "cluster_x: " + data["cluster_x"].map(lambda v: f"{v:.4f}")
        + ", cluster_y: " + data["cluster_y"].map(lambda v: f"{v:.4f}")
        + ", cluster_z: " + data["cluster_z"].map(lambda v: f"{v:.4f}")
    )
    ids = data["cluster_id"].values
    fig = go.Figure()
    # One trace per cluster so the legend can toggle clusters individually.
    for cid in sorted({int(c) for c in ids}):
        selector = ids == cid
        points = data[selector]
        fig.add_trace(go.Scatter3d(
            x=points["cluster_x"].tolist(),
            y=points["cluster_y"].tolist(),
            z=points["cluster_z"].tolist(),
            mode="markers",
            name="noise" if cid == 0 else f"cluster {cid}",
            marker=dict(size=sizes_all[selector].tolist(), opacity=0.8),
            hovertext=points["_hover"].tolist(),
            hoverinfo="text",
        ))
    fig.update_layout(
        title="Clustering Space 3D Map (GATr regressed coordinates)",
        scene=dict(
            xaxis_title="cluster_x",
            yaxis_title="cluster_y",
            zaxis_title="cluster_z",
        ),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
# ---------------------------------------------------------------------------
# Main inference entry point for the UI
# ---------------------------------------------------------------------------
def _compute_inv_mass(df, e_col, px_col, py_col, pz_col):
"""Compute the invariant mass of a system of particles in GeV.
Returns the scalar invariant mass m = sqrt(max((ΣE)²−(Σpx)²−(Σpy)²−(Σpz)², 0)),
or *None* when *df* is empty or the required columns are absent.
"""
if df.empty:
return None
for col in (e_col, px_col, py_col, pz_col):
if col not in df.columns:
return None
E = float(df[e_col].sum())
px = float(df[px_col].sum())
py = float(df[py_col].sum())
pz = float(df[pz_col].sum())
m2 = E ** 2 - px ** 2 - py ** 2 - pz ** 2
return float(np.sqrt(max(m2, 0.0)))
def _fmt_mass(val):
"""Format an invariant-mass value (float or None) as a GeV string."""
return f"{val:.4f} GeV" if val is not None else "N/A"
def run_inference_ui(
    parquet_path: str,
    event_index: int,
    csv_hits: str,
    csv_tracks: str,
    csv_particles: str,
    csv_pandora: str,
    csv_pfo_links: str = "",
):
    """Run inference on a single event, return predicted particles, 3D plots, MC particles and Pandora particles.

    Pasted CSV hit data takes precedence over the parquet file selection.

    Returns
    -------
    particles_df : pandas.DataFrame
    cluster_fig : plotly.graph_objects.Figure
    clustering_space_fig : plotly.graph_objects.Figure
    pandora_cluster_fig : plotly.graph_objects.Figure
    mc_particles_df : pandas.DataFrame
    pandora_particles_df : pandas.DataFrame
    inv_mass_summary : str
    """
    global _MODEL, _ARGS, _DEVICE
    empty_fig = go.Figure()

    def _failure(message_df):
        # Every error path returns the same 7-tuple shape the UI outputs expect;
        # previously this tuple was written out three times.
        return (message_df, empty_fig, empty_fig, empty_fig, pd.DataFrame(), pd.DataFrame(), "")

    if _MODEL is None:
        return _failure(pd.DataFrame({"error": ["Model not loaded. Please load a model first."]}))
    try:
        from src.inference import load_event_from_parquet, run_single_event_inference
        # Decide input source: CSV text wins over the parquet selection.
        use_parquet = parquet_path and os.path.isfile(parquet_path)
        use_csv = bool(csv_hits and csv_hits.strip())
        if not use_parquet and not use_csv:
            return _failure(pd.DataFrame({"error": ["Provide a parquet file or paste CSV hit data."]}))
        if use_csv:
            event = _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links)
        else:
            event = load_event_from_parquet(parquet_path, int(event_index))
        particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df = run_single_event_inference(
            event, _MODEL, _ARGS, device=_DEVICE,
        )
        if particles_df.empty:
            particles_df = pd.DataFrame({"info": ["Event produced no clusters (empty graph)."]})
        cluster_fig = _build_cluster_plot(hit_cluster_df)
        clustering_space_fig = _build_clustering_space_plot(hit_cluster_df)
        pandora_cluster_fig = _build_pandora_cluster_plot(hit_cluster_df)
        # Compute invariant masses [GeV] of the summed 4-vectors per algorithm.
        m_true = _compute_inv_mass(mc_particles_df, "energy", "px", "py", "pz")
        # HitPF uses corrected_energy when available, otherwise energy_sum_hits
        hitpf_e_col = "corrected_energy" if "corrected_energy" in particles_df.columns else "energy_sum_hits"
        m_reco_hitpf = _compute_inv_mass(particles_df, hitpf_e_col, "px", "py", "pz")
        m_reco_pandora = _compute_inv_mass(pandora_particles_df, "energy", "px", "py", "pz")
        inv_mass_summary = (
            f"**Invariant mass (sum of all particle 4-vectors)**\n\n"
            f"| Algorithm | m [GeV] |\n"
            f"|---|---|\n"
            f"| m_true (MC truth) | {_fmt_mass(m_true)} |\n"
            f"| m_reco (HitPF) | {_fmt_mass(m_reco_hitpf)} |\n"
            f"| m_reco (Pandora) | {_fmt_mass(m_reco_pandora)} |"
        )
        return particles_df, cluster_fig, clustering_space_fig, pandora_cluster_fig, mc_particles_df, pandora_particles_df, inv_mass_summary
    except Exception:
        # Surface the full traceback in the particles table for debugging.
        return _failure(pd.DataFrame({"error": [traceback.format_exc()]}))
def _parse_csv_event(csv_hits: str, csv_tracks: str, csv_particles: str, csv_pandora: str = "", csv_pfo_links: str = ""):
"""Parse user-provided CSV text into the dict-of-arrays format expected by
``create_graph``.
Expected CSV columns for hits (X_hit) — 11 columns:
0: hit_x — hit position x [mm]
1: hit_y — hit position y [mm]
2: hit_z — hit position z [mm]
3: hit_px — hit momentum px [GeV] (0 for calo hits)
4: hit_py — hit momentum py [GeV] (0 for calo hits)
5: hit_energy — hit energy deposit [GeV]
6: hit_x_calo — hit position x at calorimeter surface [mm] (used as 3D position by the model)
7: hit_y_calo — hit position y at calorimeter surface [mm]
8: hit_z_calo — hit position z at calorimeter surface [mm]
9: (unused) — reserved column (set to 0)
10: hit_type — hit sub-detector type: 1 = ECAL, 2 = HCAL, 3 = muon system
Expected CSV columns for tracks (X_track) — 25 columns (padded with
zeros if fewer are provided; minimum 17):
0: elemtype — element type (always 1 for tracks)
1–4: (unused) — reserved columns (set to 0)
5: p — track momentum magnitude |p| [GeV]
6: px_IP — track px at interaction point [GeV]
7: py_IP — track py at interaction point [GeV]
8: pz_IP — track pz at interaction point [GeV]
9–11: (unused) — reserved columns (set to 0)
12: ref_x_calo — track reference-point x at calorimeter [mm]
13: ref_y_calo — track reference-point y at calorimeter [mm]
14: ref_z_calo — track reference-point z at calorimeter [mm]
15: chi2 — track-fit chi-squared
16: ndf — track-fit number of degrees of freedom
17–21: (unused) — reserved columns (set to 0)
22: px_calo — track momentum x component at calorimeter [GeV]
23: py_calo — track momentum y component at calorimeter [GeV]
24: pz_calo — track momentum z component at calorimeter [GeV]
Expected CSV columns for particles / MC truth (X_gen) — 18 columns:
0: pid — PDG particle ID (e.g. 211, 22, 11, 13)
1: gen_status — generator status code
2: isDecayedInCalo — 1 if decayed in calorimeter, else 0
3: isDecayedInTracker — 1 if decayed in tracker, else 0
4: theta — polar angle [rad]
5: phi — azimuthal angle [rad]
6: (unused) — reserved (set to 0)
7: (unused) — reserved (set to 0)
8: energy — true particle energy [GeV]
9: (unused) — reserved (set to 0)
10: mass — particle mass [GeV]
11: momentum — momentum magnitude |p| [GeV]
12: px — momentum x component [GeV]
13: py — momentum y component [GeV]
14: pz — momentum z component [GeV]
15: vx — production vertex x [mm]
16: vy — production vertex y [mm]
17: vz — production vertex z [mm]
PFO links (csv_pfo_links) — two lines of comma-separated integers:
Line 1: pfo_calohit — one PFO index per calorimeter hit (-1 = unassigned)
Line 2: pfo_track — one PFO index per track (-1 = unassigned)
"""
import io
import awkward as ak
def _read(text, min_cols=1):
if not text or not text.strip():
return np.zeros((0, min_cols), dtype=np.float64)
df = pd.read_csv(io.StringIO(text), header=None)
return df.values.astype(np.float64)
hits_arr = _read(csv_hits, 11)
tracks_arr = _read(csv_tracks, 25)
particles_arr = _read(csv_particles, 18)
pandora_arr = _read(csv_pandora, 9)
# Pad tracks to 25 columns if needed
if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
# Build ygen_hit / ygen_track (particle link per hit — use -1 for unknown)
ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
# Parse PFO link arrays (hit → Pandora cluster mapping)
pfo_calohit = np.array([], dtype=np.int64)
pfo_track = np.array([], dtype=np.int64)
if csv_pfo_links and csv_pfo_links.strip():
lines = csv_pfo_links.strip().split("\n")
if len(lines) >= 1 and lines[0].strip():
pfo_calohit = np.array(
[int(v) for v in lines[0].strip().split(",")], dtype=np.int64
)
if len(lines) >= 2 and lines[1].strip():
pfo_track = np.array(
[int(v) for v in lines[1].strip().split(",")], dtype=np.int64
)
event = {
"X_hit": hits_arr,
"X_track": tracks_arr,
"X_gen": particles_arr,
"X_pandora": pandora_arr,
"ygen_hit": ygen_hit,
"ygen_track": ygen_track,
"pfo_calohit": pfo_calohit,
"pfo_track": pfo_track,
}
return event
# ---------------------------------------------------------------------------
# Build the Gradio interface
# ---------------------------------------------------------------------------
def build_app():
    """Construct and return the Gradio Blocks UI.

    Three sections: (1) model loading, (2) event selection — either from a
    parquet file or pasted CSV text, (3) inference results (tables, invariant
    masses, and 3D plots). All callbacks are wired to the module-level
    functions defined above.
    """
    with gr.Blocks(title="HitPF — Single-event MLPF Inference") as demo:
        gr.Markdown(
            "# HitPF — Single-event MLPF Inference\n"
            "Run the GATr-based particle-flow reconstruction on a single event.\n\n"
            "**Steps:** 1) Load model checkpoints 2) Select an event 3) Run inference"
        )
        # ---- Model loading ----
        with gr.Accordion("1 · Load Model", open=True):
            with gr.Row():
                clustering_ckpt = gr.Textbox(
                    label="Clustering checkpoint (.ckpt)",
                    value="model_clustering.ckpt",
                    placeholder="/path/to/clustering.ckpt",
                )
                energy_pid_ckpt = gr.Textbox(
                    label="Energy / PID checkpoint (.ckpt) — optional",
                    value="model_e_pid.ckpt",
                    placeholder="/path/to/energy_pid.ckpt",
                )
                device_dd = gr.Dropdown(
                    choices=["cpu", "cuda:0", "cuda:1"],
                    value="cpu",
                    label="Device",
                )
            load_btn = gr.Button("Load model")
            load_status = gr.Markdown("")
            load_btn.click(
                fn=load_model_ui,
                inputs=[clustering_ckpt, energy_pid_ckpt, device_dd],
                outputs=load_status,
            )
        # ---- Event selection ----
        with gr.Accordion("2 · Select Event", open=True):
            gr.Markdown("**Option A** — from a parquet file:")
            with gr.Row():
                parquet_path = gr.Textbox(
                    label="Parquet file path",
                    value="test_data.parquet",
                    placeholder="/path/to/events.parquet",
                )
                event_idx = gr.Number(label="Event index", value=0, precision=0)
            parquet_info = gr.Markdown("")
            # Refresh the event-count hint whenever the path changes.
            parquet_path.change(
                fn=_count_events_in_parquet,
                inputs=parquet_path,
                outputs=parquet_info,
            )
            load_event_btn = gr.Button("Load event from parquet")
            load_event_status = gr.Markdown("")
            gr.Markdown(
                "---\n**Option B** — paste CSV data (one row per hit/track/particle, "
                "no header, comma-separated):\n"
            )
            csv_hits = gr.Textbox(
                label="Hits CSV (11 columns)",
                lines=4,
                placeholder=(
                    "Example (one ECAL hit, one HCAL hit):\n"
                    "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
                    "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
                ),
            )
            csv_tracks = gr.Textbox(
                label="Tracks CSV (25 columns; leave empty if none)",
                lines=3,
                placeholder=(
                    "Example (one track with p≈5 GeV):\n"
                    "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
                ),
            )
            csv_particles = gr.Textbox(
                label="Particles (MC truth) CSV (18 columns; optional)",
                lines=3,
                placeholder=(
                    "Example (one pion, one photon):\n"
                    "211,1,0,0,1.2,0.5,0,0,5.2,0,0.1396,5.198,3.1,2.0,3.3,0,0,0\n"
                    "22,1,0,0,0.8,2.1,0,0,1.5,0,0,1.5,0.5,-0.3,1.38,0,0,0"
                ),
            )
            csv_pandora = gr.Textbox(
                label="Pandora PFOs CSV (9 columns; optional)",
                lines=3,
                placeholder=(
                    "Columns: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum\n"
                    "Example (one charged pion PFO):\n"
                    "211,3.0,2.0,3.3,1800.0,150.0,90.0,5.2,5.198"
                ),
            )
            csv_pfo_links = gr.Textbox(
                label="Hit → Pandora Cluster links (optional; loaded from parquet)",
                lines=2,
                placeholder=(
                    "Line 1: PFO index per calo hit (comma-separated, -1 = unassigned)\n"
                    "Line 2: PFO index per track (comma-separated, -1 = unassigned)"
                ),
            )
            # "Load event" fills all CSV boxes from the selected parquet event.
            load_event_btn.click(
                fn=_load_event_into_csv,
                inputs=[parquet_path, event_idx],
                outputs=[csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links, load_event_status],
            )
        # ---- Run inference ----
        with gr.Accordion("3 · Results", open=True):
            run_btn = gr.Button("▶ Run Inference", variant="primary")
            inv_mass_output = gr.Markdown("")
            gr.Markdown("### Predicted Particles (HitPF)")
            particles_table = gr.Dataframe(label="Predicted particles")
            gr.Markdown("### MC Truth Particles")
            mc_particles_table = gr.Dataframe(label="MC truth particles (for comparison)")
            gr.Markdown("### Pandora Particles")
            pandora_particles_table = gr.Dataframe(label="Pandora PFO particles (for comparison)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Hit → HitPF Cluster 3D Map")
                    cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = HitPF cluster, size = energy)")
                with gr.Column():
                    gr.Markdown("### Hit → Pandora Cluster 3D Map")
                    pandora_cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = Pandora PFO, size = energy)")
            gr.Markdown("### Clustering Space 3D Map")
            clustering_space_plot = gr.Plot(label="Clustering space 3D scatter (GATr regressed coordinates)")
            # Output order must match the 7-tuple returned by run_inference_ui.
            run_btn.click(
                fn=run_inference_ui,
                inputs=[parquet_path, event_idx, csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links],
                outputs=[particles_table, cluster_plot, clustering_space_plot, pandora_cluster_plot, mc_particles_table, pandora_particles_table, inv_mass_output],
            )
    return demo
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # CLI entry point: pick the inference device and optionally share the UI.
    parser = argparse.ArgumentParser(description="HitPF Gradio UI")
    parser.add_argument("--device", default="cpu", help="Default device (cpu / cuda:0 / …)")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    opts = parser.parse_args()
    _set_device(opts.device)
    demo = build_app()
    demo.launch(share=opts.share)