# (Hugging Face Spaces status banner — page-scrape artifact, not application code)
| #!/usr/bin/env python | |
| """ | |
| Gradio UI for single-event MLPF inference. | |
| Launch with: | |
| python app.py [--device cpu] | |
| The UI lets you: | |
| 1. Load an event from a parquet file (pick file + event index), **or** | |
| paste hit / track / particle data in CSV format. | |
| 2. (Optionally) load pre-trained model checkpoints. | |
| 3. Run inference → view predicted particles and the hit→cluster mapping. | |
| """ | |
| import argparse | |
| import os | |
| import shutil | |
| import traceback | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------------------------------------------------------- | |
| # Auto-download demo files from Hugging Face Hub if they are not present | |
| # --------------------------------------------------------------------------- | |
| _HF_REPO_ID = "gregorkrzmanc/hitpf_demo_files" | |
| _DEMO_FILES = [ | |
| "model_clustering.ckpt", | |
| "model_e_pid.ckpt", | |
| "test_data.parquet", | |
| ] | |
| def _ensure_demo_files(dest_dir: str = ".") -> None: | |
| """Download demo files from Hugging Face Hub if they don't already exist.""" | |
| for fname in _DEMO_FILES: | |
| dest = os.path.join(dest_dir, fname) | |
| if not os.path.isfile(dest): | |
| try: | |
| print(f"Downloading {fname} from HF Hub ({_HF_REPO_ID}) …") | |
| downloaded = hf_hub_download( | |
| repo_id=_HF_REPO_ID, | |
| filename=fname, | |
| repo_type="dataset", | |
| ) | |
| shutil.copy(downloaded, dest) | |
| print(f" → saved to {dest}") | |
| except Exception as exc: | |
| print(f" ⚠️ Could not download {fname}: {exc}") | |
| _ensure_demo_files() | |
| # --------------------------------------------------------------------------- | |
| # Global state – filled lazily | |
| # --------------------------------------------------------------------------- | |
| _MODEL = None | |
| _ARGS = None | |
| _DEVICE = "cpu" | |
| def _set_device(device: str): | |
| global _DEVICE | |
| _DEVICE = device | |
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model_ui(clustering_ckpt: str, energy_pid_ckpt: str, device: str):
    """Load model from checkpoint paths (called by the UI button).

    Returns a markdown status string for the UI; on failure the full
    traceback is embedded in a fenced code block.
    """
    global _MODEL, _ARGS, _DEVICE
    _DEVICE = device or "cpu"
    if not (clustering_ckpt and os.path.isfile(clustering_ckpt)):
        return "⚠️ Please provide a valid path to the clustering checkpoint."
    # The energy/PID checkpoint is optional — silently skipped when missing.
    energy_pid = None
    if energy_pid_ckpt and os.path.isfile(energy_pid_ckpt):
        energy_pid = energy_pid_ckpt
    try:
        from src.inference import load_model

        _MODEL, _ARGS = load_model(
            clustering_ckpt=clustering_ckpt,
            energy_pid_ckpt=energy_pid,
            device=_DEVICE,
        )
    except Exception:
        return f"❌ Failed to load model:\n```\n{traceback.format_exc()}\n```"
    suffix = (
        " (clustering + energy/PID correction)"
        if energy_pid
        else " (clustering only — no energy/PID correction)"
    )
    return f"✅ Model loaded on **{_DEVICE}**" + suffix
| # --------------------------------------------------------------------------- | |
| # Event loading helpers | |
| # --------------------------------------------------------------------------- | |
| def _count_events_in_parquet(parquet_path: str) -> str: | |
| """Return a short info string about the parquet file.""" | |
| if not parquet_path or not os.path.isfile(parquet_path): | |
| return "No file selected" | |
| try: | |
| from src.inference import load_event_from_parquet | |
| from src.data.fileio import _read_parquet | |
| table = _read_parquet(parquet_path) | |
| n = len(table["X_track"]) | |
| return f"File has **{n}** events (indices 0–{n-1})" | |
| except Exception as e: | |
| return f"Error reading file: {e}" | |
| def _load_event_into_csv(parquet_path: str, event_index: int): | |
| """Load an event from a parquet file and return CSV strings for the text fields.""" | |
| if not parquet_path or not os.path.isfile(parquet_path): | |
| return "", "", "", "", "", "⚠️ Please provide a valid parquet file path." | |
| try: | |
| from src.inference import load_event_from_parquet | |
| event = load_event_from_parquet(parquet_path, int(event_index)) | |
| hits_arr = np.asarray(event.get("X_hit", [])) | |
| tracks_arr = np.asarray(event.get("X_track", [])) | |
| particles_arr = np.asarray(event.get("X_gen", [])) | |
| pandora_arr = np.asarray(event.get("X_pandora", [])) | |
| def _arr_to_csv(arr): | |
| if arr.ndim != 2: | |
| return "" | |
| return "\n".join(",".join(str(v) for v in row) for row in arr) | |
| def _1d_to_csv(arr): | |
| if len(arr) == 0: | |
| return "" | |
| return ",".join(str(int(v)) for v in arr) | |
| pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64) | |
| pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64) | |
| calohit_csv = _1d_to_csv(pfo_calohit) | |
| track_csv = _1d_to_csv(pfo_track) | |
| if calohit_csv and track_csv: | |
| pfo_links_csv = calohit_csv + "\n" + track_csv | |
| elif calohit_csv: | |
| pfo_links_csv = calohit_csv | |
| elif track_csv: | |
| pfo_links_csv = "\n" + track_csv | |
| else: | |
| pfo_links_csv = "" | |
| return ( | |
| _arr_to_csv(hits_arr), | |
| _arr_to_csv(tracks_arr), | |
| _arr_to_csv(particles_arr), | |
| _arr_to_csv(pandora_arr), | |
| pfo_links_csv, | |
| f"✅ Loaded event **{int(event_index)}**: " | |
| f"{hits_arr.shape[0] if hits_arr.ndim == 2 else 0} hits, " | |
| f"{tracks_arr.shape[0] if tracks_arr.ndim == 2 else 0} tracks, " | |
| f"{particles_arr.shape[0] if particles_arr.ndim == 2 else 0} MC particles, " | |
| f"{pandora_arr.shape[0] if pandora_arr.ndim == 2 else 0} Pandora PFOs", | |
| ) | |
| except Exception as e: | |
| return "", "", "", "", "", f"❌ Error loading event: {e}" | |
def _build_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by cluster ID.

    One Scatter3d trace per cluster id; marker size encodes hit energy.
    Cluster id 0 is labeled "noise".
    """
    if hit_cluster_df.empty:
        fig = go.Figure()
        fig.update_layout(title="No hit data available", height=600)
        return fig

    df = hit_cluster_df.copy()
    # Coerce coordinates/energy to numeric, then drop any NaN/Inf rows.
    coord_cols = ["x", "y", "z", "hit_energy"]
    for col in coord_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=coord_cols)
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid hit data (all NaN/Inf)", height=600)
        return fig

    # Marker size encodes energy, linearly rescaled to the 3–15 px range.
    energies = df["hit_energy"].to_numpy(dtype=float)
    lo, hi = float(energies.min()), float(energies.max())
    if hi > lo:
        scaled = (energies - lo) / (hi - lo)
    else:
        scaled = np.full_like(energies, 0.5)  # all energies equal → midpoint
    marker_sizes = 3 + scaled * 12

    # Per-hit hover text (avoids mixed-type customdata serialization issues).
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + df["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + df["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + df["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + df["z"].map(lambda v: f"{v:.2f}")
    )

    ids = df["cluster_id"].to_numpy()
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        sel = ids == cid
        part = df[sel]
        fig.add_trace(
            go.Scatter3d(
                x=part["x"].tolist(),
                y=part["y"].tolist(),
                z=part["z"].tolist(),
                mode="markers",
                name="noise" if cid == 0 else f"cluster {cid}",
                marker=dict(size=marker_sizes[sel].tolist(), opacity=0.8),
                hovertext=part["_hover"].tolist(),
                hoverinfo="text",
            )
        )
    fig.update_layout(
        title="Hit → Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
def _build_pandora_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by Pandora cluster ID.

    One Scatter3d trace per Pandora PFO id; marker size encodes hit energy.
    Hits with id −1 are kept and shown as the "unassigned" trace.
    """
    if hit_cluster_df.empty or "pandora_cluster_id" not in hit_cluster_df.columns:
        fig = go.Figure()
        fig.update_layout(title="No Pandora cluster data available", height=600)
        return fig

    df = hit_cluster_df.copy()
    # Coerce coordinates/energy to numeric, then drop any NaN/Inf rows.
    coord_cols = ["x", "y", "z", "hit_energy"]
    for col in coord_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=coord_cols)
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid hit data for Pandora plot (all NaN/Inf)", height=600)
        return fig

    # Marker size encodes energy, linearly rescaled to the 3–15 px range.
    energies = df["hit_energy"].to_numpy(dtype=float)
    lo, hi = float(energies.min()), float(energies.max())
    if hi > lo:
        scaled = (energies - lo) / (hi - lo)
    else:
        scaled = np.full_like(energies, 0.5)  # all energies equal → midpoint
    marker_sizes = 3 + scaled * 12

    # Per-hit hover text.
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Pandora cluster: " + df["pandora_cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + df["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + df["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + df["z"].map(lambda v: f"{v:.2f}")
    )

    ids = df["pandora_cluster_id"].to_numpy()
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        sel = ids == cid
        part = df[sel]
        fig.add_trace(
            go.Scatter3d(
                x=part["x"].tolist(),
                y=part["y"].tolist(),
                z=part["z"].tolist(),
                mode="markers",
                name="unassigned" if cid == -1 else f"PFO {cid}",
                marker=dict(size=marker_sizes[sel].tolist(), opacity=0.8),
                hovertext=part["_hover"].tolist(),
                hoverinfo="text",
            )
        )
    fig.update_layout(
        title="Hit → Pandora Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Pandora PFO",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
def _build_clustering_space_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits in the learned clustering space.

    Same trace/marker conventions as the physical-space plot, but positions
    come from the regressed (cluster_x, cluster_y, cluster_z) columns.
    """
    if hit_cluster_df.empty or "cluster_x" not in hit_cluster_df.columns:
        fig = go.Figure()
        fig.update_layout(title="No clustering-space data available", height=600)
        return fig

    df = hit_cluster_df.copy()
    # Coerce coordinates/energy to numeric, then drop any NaN/Inf rows.
    coord_cols = ["cluster_x", "cluster_y", "cluster_z", "hit_energy"]
    for col in coord_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=coord_cols)
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid clustering-space data (all NaN/Inf)", height=600)
        return fig

    # Marker size encodes energy, linearly rescaled to the 3–15 px range.
    energies = df["hit_energy"].to_numpy(dtype=float)
    lo, hi = float(energies.min()), float(energies.max())
    if hi > lo:
        scaled = (energies - lo) / (hi - lo)
    else:
        scaled = np.full_like(energies, 0.5)  # all energies equal → midpoint
    marker_sizes = 3 + scaled * 12

    # Per-hit hover text.
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + df["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "cluster_x: " + df["cluster_x"].map(lambda v: f"{v:.4f}")
        + ", cluster_y: " + df["cluster_y"].map(lambda v: f"{v:.4f}")
        + ", cluster_z: " + df["cluster_z"].map(lambda v: f"{v:.4f}")
    )

    ids = df["cluster_id"].to_numpy()
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        sel = ids == cid
        part = df[sel]
        fig.add_trace(
            go.Scatter3d(
                x=part["cluster_x"].tolist(),
                y=part["cluster_y"].tolist(),
                z=part["cluster_z"].tolist(),
                mode="markers",
                name="noise" if cid == 0 else f"cluster {cid}",
                marker=dict(size=marker_sizes[sel].tolist(), opacity=0.8),
                hovertext=part["_hover"].tolist(),
                hoverinfo="text",
            )
        )
    fig.update_layout(
        title="Clustering Space 3D Map (GATr regressed coordinates)",
        scene=dict(
            xaxis_title="cluster_x",
            yaxis_title="cluster_y",
            zaxis_title="cluster_z",
        ),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
| # --------------------------------------------------------------------------- | |
| # Main inference entry point for the UI | |
| # --------------------------------------------------------------------------- | |
| def _compute_inv_mass(df, e_col, px_col, py_col, pz_col): | |
| """Compute the invariant mass of a system of particles in GeV. | |
| Returns the scalar invariant mass m = sqrt(max((ΣE)²−(Σpx)²−(Σpy)²−(Σpz)², 0)), | |
| or *None* when *df* is empty or the required columns are absent. | |
| """ | |
| if df.empty: | |
| return None | |
| for col in (e_col, px_col, py_col, pz_col): | |
| if col not in df.columns: | |
| return None | |
| E = float(df[e_col].sum()) | |
| px = float(df[px_col].sum()) | |
| py = float(df[py_col].sum()) | |
| pz = float(df[pz_col].sum()) | |
| m2 = E ** 2 - px ** 2 - py ** 2 - pz ** 2 | |
| return float(np.sqrt(max(m2, 0.0))) | |
| def _fmt_mass(val): | |
| """Format an invariant-mass value (float or None) as a GeV string.""" | |
| return f"{val:.4f} GeV" if val is not None else "N/A" | |
def run_inference_ui(
    parquet_path: str,
    event_index: int,
    csv_hits: str,
    csv_tracks: str,
    csv_particles: str,
    csv_pandora: str,
    csv_pfo_links: str = "",
):
    """Run inference on a single event, return predicted particles, 3D plots, MC particles and Pandora particles.

    Returns
    -------
    particles_df : pandas.DataFrame
    cluster_fig : plotly.graph_objects.Figure
    clustering_space_fig : plotly.graph_objects.Figure
    pandora_cluster_fig : plotly.graph_objects.Figure
    mc_particles_df : pandas.DataFrame
    pandora_particles_df : pandas.DataFrame
    inv_mass_summary : str
    """
    global _MODEL, _ARGS, _DEVICE
    blank = go.Figure()

    def _error(df: pd.DataFrame):
        # Uniform 7-tuple shape expected by the Gradio output components.
        return df, blank, blank, blank, pd.DataFrame(), pd.DataFrame(), ""

    if _MODEL is None:
        return _error(pd.DataFrame({"error": ["Model not loaded. Please load a model first."]}))
    try:
        from src.inference import load_event_from_parquet, run_single_event_inference

        # Decide input source: pasted CSV (option B) wins over the parquet path.
        have_parquet = parquet_path and os.path.isfile(parquet_path)
        have_csv = bool(csv_hits and csv_hits.strip())
        if not have_parquet and not have_csv:
            return _error(pd.DataFrame({"error": ["Provide a parquet file or paste CSV hit data."]}))
        if have_csv:
            event = _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links)
        else:
            event = load_event_from_parquet(parquet_path, int(event_index))

        particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df = run_single_event_inference(
            event, _MODEL, _ARGS, device=_DEVICE,
        )
        if particles_df.empty:
            particles_df = pd.DataFrame({"info": ["Event produced no clusters (empty graph)."]})

        cluster_fig = _build_cluster_plot(hit_cluster_df)
        clustering_space_fig = _build_clustering_space_plot(hit_cluster_df)
        pandora_cluster_fig = _build_pandora_cluster_plot(hit_cluster_df)

        # Compute invariant masses [GeV] of the whole event, per algorithm.
        m_true = _compute_inv_mass(mc_particles_df, "energy", "px", "py", "pz")
        # HitPF uses corrected_energy when available, otherwise energy_sum_hits
        hitpf_e_col = "corrected_energy" if "corrected_energy" in particles_df.columns else "energy_sum_hits"
        m_reco_hitpf = _compute_inv_mass(particles_df, hitpf_e_col, "px", "py", "pz")
        m_reco_pandora = _compute_inv_mass(pandora_particles_df, "energy", "px", "py", "pz")
        inv_mass_summary = (
            f"**Invariant mass (sum of all particle 4-vectors)**\n\n"
            f"| Algorithm | m [GeV] |\n"
            f"|---|---|\n"
            f"| m_true (MC truth) | {_fmt_mass(m_true)} |\n"
            f"| m_reco (HitPF) | {_fmt_mass(m_reco_hitpf)} |\n"
            f"| m_reco (Pandora) | {_fmt_mass(m_reco_pandora)} |"
        )
        return particles_df, cluster_fig, clustering_space_fig, pandora_cluster_fig, mc_particles_df, pandora_particles_df, inv_mass_summary
    except Exception:
        # Surface the full traceback in the particles table rather than crashing the UI.
        return _error(pd.DataFrame({"error": [traceback.format_exc()]}))
| def _parse_csv_event(csv_hits: str, csv_tracks: str, csv_particles: str, csv_pandora: str = "", csv_pfo_links: str = ""): | |
| """Parse user-provided CSV text into the dict-of-arrays format expected by | |
| ``create_graph``. | |
| Expected CSV columns for hits (X_hit) — 11 columns: | |
| 0: hit_x — hit position x [mm] | |
| 1: hit_y — hit position y [mm] | |
| 2: hit_z — hit position z [mm] | |
| 3: hit_px — hit momentum px [GeV] (0 for calo hits) | |
| 4: hit_py — hit momentum py [GeV] (0 for calo hits) | |
| 5: hit_energy — hit energy deposit [GeV] | |
| 6: hit_x_calo — hit position x at calorimeter surface [mm] (used as 3D position by the model) | |
| 7: hit_y_calo — hit position y at calorimeter surface [mm] | |
| 8: hit_z_calo — hit position z at calorimeter surface [mm] | |
| 9: (unused) — reserved column (set to 0) | |
| 10: hit_type — hit sub-detector type: 1 = ECAL, 2 = HCAL, 3 = muon system | |
| Expected CSV columns for tracks (X_track) — 25 columns (padded with | |
| zeros if fewer are provided; minimum 17): | |
| 0: elemtype — element type (always 1 for tracks) | |
| 1–4: (unused) — reserved columns (set to 0) | |
| 5: p — track momentum magnitude |p| [GeV] | |
| 6: px_IP — track px at interaction point [GeV] | |
| 7: py_IP — track py at interaction point [GeV] | |
| 8: pz_IP — track pz at interaction point [GeV] | |
| 9–11: (unused) — reserved columns (set to 0) | |
| 12: ref_x_calo — track reference-point x at calorimeter [mm] | |
| 13: ref_y_calo — track reference-point y at calorimeter [mm] | |
| 14: ref_z_calo — track reference-point z at calorimeter [mm] | |
| 15: chi2 — track-fit chi-squared | |
| 16: ndf — track-fit number of degrees of freedom | |
| 17–21: (unused) — reserved columns (set to 0) | |
| 22: px_calo — track momentum x component at calorimeter [GeV] | |
| 23: py_calo — track momentum y component at calorimeter [GeV] | |
| 24: pz_calo — track momentum z component at calorimeter [GeV] | |
| Expected CSV columns for particles / MC truth (X_gen) — 18 columns: | |
| 0: pid — PDG particle ID (e.g. 211, 22, 11, 13) | |
| 1: gen_status — generator status code | |
| 2: isDecayedInCalo — 1 if decayed in calorimeter, else 0 | |
| 3: isDecayedInTracker — 1 if decayed in tracker, else 0 | |
| 4: theta — polar angle [rad] | |
| 5: phi — azimuthal angle [rad] | |
| 6: (unused) — reserved (set to 0) | |
| 7: (unused) — reserved (set to 0) | |
| 8: energy — true particle energy [GeV] | |
| 9: (unused) — reserved (set to 0) | |
| 10: mass — particle mass [GeV] | |
| 11: momentum — momentum magnitude |p| [GeV] | |
| 12: px — momentum x component [GeV] | |
| 13: py — momentum y component [GeV] | |
| 14: pz — momentum z component [GeV] | |
| 15: vx — production vertex x [mm] | |
| 16: vy — production vertex y [mm] | |
| 17: vz — production vertex z [mm] | |
| PFO links (csv_pfo_links) — two lines of comma-separated integers: | |
| Line 1: pfo_calohit — one PFO index per calorimeter hit (-1 = unassigned) | |
| Line 2: pfo_track — one PFO index per track (-1 = unassigned) | |
| """ | |
| import io | |
| import awkward as ak | |
| def _read(text, min_cols=1): | |
| if not text or not text.strip(): | |
| return np.zeros((0, min_cols), dtype=np.float64) | |
| df = pd.read_csv(io.StringIO(text), header=None) | |
| return df.values.astype(np.float64) | |
| hits_arr = _read(csv_hits, 11) | |
| tracks_arr = _read(csv_tracks, 25) | |
| particles_arr = _read(csv_particles, 18) | |
| pandora_arr = _read(csv_pandora, 9) | |
| # Pad tracks to 25 columns if needed | |
| if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0: | |
| pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1])) | |
| tracks_arr = np.concatenate([tracks_arr, pad], axis=1) | |
| # Build ygen_hit / ygen_track (particle link per hit — use -1 for unknown) | |
| ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64) | |
| ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64) | |
| # Parse PFO link arrays (hit → Pandora cluster mapping) | |
| pfo_calohit = np.array([], dtype=np.int64) | |
| pfo_track = np.array([], dtype=np.int64) | |
| if csv_pfo_links and csv_pfo_links.strip(): | |
| lines = csv_pfo_links.strip().split("\n") | |
| if len(lines) >= 1 and lines[0].strip(): | |
| pfo_calohit = np.array( | |
| [int(v) for v in lines[0].strip().split(",")], dtype=np.int64 | |
| ) | |
| if len(lines) >= 2 and lines[1].strip(): | |
| pfo_track = np.array( | |
| [int(v) for v in lines[1].strip().split(",")], dtype=np.int64 | |
| ) | |
| event = { | |
| "X_hit": hits_arr, | |
| "X_track": tracks_arr, | |
| "X_gen": particles_arr, | |
| "X_pandora": pandora_arr, | |
| "ygen_hit": ygen_hit, | |
| "ygen_track": ygen_track, | |
| "pfo_calohit": pfo_calohit, | |
| "pfo_track": pfo_track, | |
| } | |
| return event | |
# ---------------------------------------------------------------------------
# Build the Gradio interface
# ---------------------------------------------------------------------------
def build_app():
    """Assemble and return the Gradio Blocks UI.

    Layout is three accordion steps: (1) load checkpoints, (2) pick an event
    from a parquet file or paste CSV data, (3) run inference and view results.
    Component creation order inside each context manager defines the layout.
    """
    with gr.Blocks(title="HitPF — Single-event MLPF Inference") as demo:
        gr.Markdown(
            "# HitPF — Single-event MLPF Inference\n"
            "Run the GATr-based particle-flow reconstruction on a single event.\n\n"
            "**Steps:** 1) Load model checkpoints 2) Select an event 3) Run inference"
        )
        # ---- Model loading ----
        with gr.Accordion("1 · Load Model", open=True):
            with gr.Row():
                clustering_ckpt = gr.Textbox(
                    label="Clustering checkpoint (.ckpt)",
                    value="model_clustering.ckpt",
                    placeholder="/path/to/clustering.ckpt",
                )
                energy_pid_ckpt = gr.Textbox(
                    label="Energy / PID checkpoint (.ckpt) — optional",
                    value="model_e_pid.ckpt",
                    placeholder="/path/to/energy_pid.ckpt",
                )
                device_dd = gr.Dropdown(
                    choices=["cpu", "cuda:0", "cuda:1"],
                    value="cpu",
                    label="Device",
                )
            load_btn = gr.Button("Load model")
            load_status = gr.Markdown("")
            load_btn.click(
                fn=load_model_ui,
                inputs=[clustering_ckpt, energy_pid_ckpt, device_dd],
                outputs=load_status,
            )
        # ---- Event selection ----
        with gr.Accordion("2 · Select Event", open=True):
            gr.Markdown("**Option A** — from a parquet file:")
            with gr.Row():
                parquet_path = gr.Textbox(
                    label="Parquet file path",
                    value="test_data.parquet",
                    placeholder="/path/to/events.parquet",
                )
                event_idx = gr.Number(label="Event index", value=0, precision=0)
            parquet_info = gr.Markdown("")
            # Re-count events whenever the path text changes.
            parquet_path.change(
                fn=_count_events_in_parquet,
                inputs=parquet_path,
                outputs=parquet_info,
            )
            load_event_btn = gr.Button("Load event from parquet")
            load_event_status = gr.Markdown("")
            gr.Markdown(
                "---\n**Option B** — paste CSV data (one row per hit/track/particle, "
                "no header, comma-separated):\n"
            )
            csv_hits = gr.Textbox(
                label="Hits CSV (11 columns)",
                lines=4,
                placeholder=(
                    "Example (one ECAL hit, one HCAL hit):\n"
                    "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
                    "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
                ),
            )
            csv_tracks = gr.Textbox(
                label="Tracks CSV (25 columns; leave empty if none)",
                lines=3,
                placeholder=(
                    "Example (one track with p≈5 GeV):\n"
                    "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
                ),
            )
            csv_particles = gr.Textbox(
                label="Particles (MC truth) CSV (18 columns; optional)",
                lines=3,
                placeholder=(
                    "Example (one pion, one photon):\n"
                    "211,1,0,0,1.2,0.5,0,0,5.2,0,0.1396,5.198,3.1,2.0,3.3,0,0,0\n"
                    "22,1,0,0,0.8,2.1,0,0,1.5,0,0,1.5,0.5,-0.3,1.38,0,0,0"
                ),
            )
            csv_pandora = gr.Textbox(
                label="Pandora PFOs CSV (9 columns; optional)",
                lines=3,
                placeholder=(
                    "Columns: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum\n"
                    "Example (one charged pion PFO):\n"
                    "211,3.0,2.0,3.3,1800.0,150.0,90.0,5.2,5.198"
                ),
            )
            csv_pfo_links = gr.Textbox(
                label="Hit → Pandora Cluster links (optional; loaded from parquet)",
                lines=2,
                placeholder=(
                    "Line 1: PFO index per calo hit (comma-separated, -1 = unassigned)\n"
                    "Line 2: PFO index per track (comma-separated, -1 = unassigned)"
                ),
            )
            # Loading an event fills the CSV textboxes so users can inspect/edit it.
            load_event_btn.click(
                fn=_load_event_into_csv,
                inputs=[parquet_path, event_idx],
                outputs=[csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links, load_event_status],
            )
        # ---- Run inference ----
        with gr.Accordion("3 · Results", open=True):
            run_btn = gr.Button("▶ Run Inference", variant="primary")
            inv_mass_output = gr.Markdown("")
            gr.Markdown("### Predicted Particles (HitPF)")
            particles_table = gr.Dataframe(label="Predicted particles")
            gr.Markdown("### MC Truth Particles")
            mc_particles_table = gr.Dataframe(label="MC truth particles (for comparison)")
            gr.Markdown("### Pandora Particles")
            pandora_particles_table = gr.Dataframe(label="Pandora PFO particles (for comparison)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Hit → HitPF Cluster 3D Map")
                    cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = HitPF cluster, size = energy)")
                with gr.Column():
                    gr.Markdown("### Hit → Pandora Cluster 3D Map")
                    pandora_cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = Pandora PFO, size = energy)")
            gr.Markdown("### Clustering Space 3D Map")
            clustering_space_plot = gr.Plot(label="Clustering space 3D scatter (GATr regressed coordinates)")
            # Output order must match run_inference_ui's 7-tuple return.
            run_btn.click(
                fn=run_inference_ui,
                inputs=[parquet_path, event_idx, csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links],
                outputs=[particles_table, cluster_plot, clustering_space_plot, pandora_cluster_plot, mc_particles_table, pandora_particles_table, inv_mass_output],
            )
    return demo
# ---------------------------------------------------------------------------
# CLI entry point: parse flags, set the default device, launch the UI.
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="HitPF Gradio UI")
    ap.add_argument("--device", default="cpu", help="Default device (cpu / cuda:0 / …)")
    ap.add_argument("--share", action="store_true", help="Create a public Gradio link")
    cli_args = ap.parse_args()
    _set_device(cli_args.device)  # record CLI choice as the module-wide default
    demo = build_app()
    demo.launch(share=cli_args.share)