""" ColliderML Event Display Interactive 3D visualisation of single events from ColliderML datasets. Loads tracker hits, truth particles, and reconstructed tracks from a chosen event and renders them with Plotly. Runs as a Gradio HuggingFace Space. """ from functools import lru_cache from pathlib import Path from typing import Any, Dict import gradio as gr import numpy as np import plotly.graph_objects as go import pyarrow as pa import pyarrow.parquet as pq try: import colliderml HAS_COLLIDERML = True except ImportError: HAS_COLLIDERML = False # Datasets exposed in the UI. Kept short so the dropdown isn't overwhelming. DATASETS = [ "ttbar_pu0", "ttbar_pu200", "higgs_portal_pu10", "zmumu_pu0", "diphoton_pu0", "single_muon_pu0", ] # How many events to cache per dataset on the Space host (limits memory). EVENTS_PER_DATASET = 50 # Local cache directory (populated by cache_events.py at build time). CACHE_DIR = Path(__file__).parent / "_cached_events" def _frame_to_arrow(frame: Any) -> pa.Table: """Normalise a Polars frame (or already-pyarrow table) into a pyarrow Table. ``colliderml.load`` returns Polars frames in v0.4.0+; the rest of this Space is pyarrow-native because we want the row-level ``filter`` / ``to_pandas`` conveniences. """ if isinstance(frame, pa.Table): return frame if hasattr(frame, "to_arrow"): collected = frame.collect() if hasattr(frame, "collect") else frame return collected.to_arrow() raise TypeError(f"Cannot convert {type(frame).__name__} to pyarrow.Table") @lru_cache(maxsize=len(DATASETS)) def _load_dataset(name: str) -> Dict[str, pa.Table]: """Return a dict of pyarrow Tables for the given dataset, cached in memory. Tries the local cache first (pre-populated by cache_events.py), then falls back to ``colliderml.load()`` which downloads on demand. """ local = CACHE_DIR / name if local.exists(): tables: Dict[str, pa.Table] = {} for parquet in local.glob("*.parquet"): tables[parquet.stem] = pq.read_table(str(parquet)) if tables: return tables if not HAS_COLLIDERML: return {} try: polars_frames = colliderml.load( name, tables=["tracker_hits", "particles", "tracks"], max_events=EVENTS_PER_DATASET, ) except Exception as exc: print(f"Failed to load {name}: {exc}") return {} return {key: _frame_to_arrow(frame) for key, frame in polars_frames.items()} def _filter_event(table, event_id): """Return rows where event_id == event_id, as a pandas DataFrame.""" if table is None: return None col_names = set(table.column_names) if "event_id" not in col_names: return table.to_pandas() mask = np.asarray(table.column("event_id")) == event_id return table.filter(mask).to_pandas() def _track_polyline(track, n_points=40): """Approximate a reconstructed track as a 3D polyline. Uses a simple helical extrapolation from the perigee parameters (d0, z0, phi, theta, qop). This is a visualisation aid, not physics. """ d0 = track.get("d0", 0.0) z0 = track.get("z0", 0.0) phi = track.get("phi", 0.0) theta = track.get("theta", np.pi / 2) qop = track.get("qop", 1e-6) # Straight-line projection to ~1 m (good enough for a visual). s = np.linspace(0, 1000, n_points) # mm x = -d0 * np.sin(phi) + s * np.sin(theta) * np.cos(phi) y = d0 * np.cos(phi) + s * np.sin(theta) * np.sin(phi) z = z0 + s * np.cos(theta) # Apply a small curvature based on qop (1/GeV momentum magnitude). if abs(qop) > 1e-9: curvature = float(qop) * 0.3 # rough mm^-1 at 1 T x += 0.5 * curvature * (s ** 2) * (-np.sin(phi)) y += 0.5 * curvature * (s ** 2) * np.cos(phi) return x, y, z def render_event(dataset_name, event_id): """Build the 3D figure for the selected dataset/event.""" tables = _load_dataset(dataset_name) if not tables: fig = go.Figure() fig.update_layout( title=f"No data available for {dataset_name}", height=700, ) return fig hits_df = _filter_event(tables.get("tracker_hits"), event_id) particles_df = _filter_event(tables.get("particles"), event_id) tracks_df = _filter_event(tables.get("tracks"), event_id) fig = go.Figure() # Tracker hits — main point cloud. if hits_df is not None and len(hits_df) > 0: color_col = None for c in ("layer_id", "volume_id", "particle_id"): if c in hits_df.columns: color_col = c break fig.add_trace(go.Scatter3d( x=hits_df["x"], y=hits_df["y"], z=hits_df["z"], mode="markers", marker=dict( size=1.6, color=hits_df[color_col] if color_col else "royalblue", colorscale="Viridis", opacity=0.85, showscale=bool(color_col), colorbar=dict(title=color_col) if color_col else None, ), name=f"Tracker hits ({len(hits_df)})", hovertemplate="x=%{x:.1f}
y=%{y:.1f}
z=%{z:.1f}", )) # Reconstructed tracks — helical polylines. if tracks_df is not None and len(tracks_df) > 0: for _, track in tracks_df.head(30).iterrows(): try: x, y, z = _track_polyline(track.to_dict()) fig.add_trace(go.Scatter3d( x=x, y=y, z=z, mode="lines", line=dict(color="crimson", width=3), name="track", showlegend=False, hoverinfo="skip", )) except Exception: continue # Truth particles — momentum vectors from primary vertex. if particles_df is not None and len(particles_df) > 0: prim = particles_df if "primary" in prim.columns: prim = prim[prim["primary"] == True] prim = prim.head(20) for _, p in prim.iterrows(): try: px, py, pz = p.get("px", 0), p.get("py", 0), p.get("pz", 0) pmag = (px ** 2 + py ** 2 + pz ** 2) ** 0.5 if pmag < 1e-3: continue scale = 500.0 / pmag fig.add_trace(go.Scatter3d( x=[0, px * scale], y=[0, py * scale], z=[0, pz * scale], mode="lines", line=dict(color="gold", width=2, dash="dash"), name=f"truth pdg={int(p.get('pdg_id', 0))}", showlegend=False, hoverinfo="name", )) except Exception: continue fig.update_layout( title=f"{dataset_name} — event {event_id}", scene=dict( xaxis_title="x [mm]", yaxis_title="y [mm]", zaxis_title="z [mm]", aspectmode="data", bgcolor="rgb(10, 10, 25)", ), height=720, margin=dict(l=0, r=0, t=40, b=0), paper_bgcolor="rgb(10, 10, 25)", font=dict(color="white"), ) return fig def event_count(dataset_name): """How many cached events does this dataset have?""" tables = _load_dataset(dataset_name) hits = tables.get("tracker_hits") if hits is None or "event_id" not in hits.column_names: return EVENTS_PER_DATASET - 1 ids = np.asarray(hits.column("event_id")) return int(ids.max()) if len(ids) else 0 def on_dataset_change(dataset_name): max_evt = event_count(dataset_name) return gr.Slider(minimum=0, maximum=max(max_evt, 0), value=0, step=1, label="Event ID") with gr.Blocks( title="ColliderML Event Display", theme=gr.themes.Soft(primary_hue="blue"), css=".gradio-container {max-width: 1200px !important;}", ) as demo: gr.Markdown( """ # ColliderML Event Display Interactive 3D view of single events from the [ColliderML datasets](https://huggingface.co/datasets/CERN/ColliderML-Release-1). - **Blue points**: tracker hits (coloured by detector layer) - **Red lines**: reconstructed tracks (helical approximation) - **Yellow dashes**: truth particle momenta from the primary vertex """ ) with gr.Row(): dataset = gr.Dropdown( DATASETS, value=DATASETS[0], label="Dataset", scale=1, ) event_slider = gr.Slider( 0, EVENTS_PER_DATASET - 1, value=0, step=1, label="Event ID", scale=2, ) plot = gr.Plot() dataset.change( fn=on_dataset_change, inputs=dataset, outputs=event_slider, ) # Render on any change. for comp in (dataset, event_slider): comp.change( fn=render_event, inputs=[dataset, event_slider], outputs=plot, ) # Initial render on app load. demo.load( fn=render_event, inputs=[dataset, event_slider], outputs=plot, ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)