File size: 9,353 Bytes
734aa32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
ColliderML Event Display

Interactive 3D visualisation of single events from ColliderML datasets.
Loads tracker hits, truth particles, and reconstructed tracks from a
chosen event and renders them with Plotly.

Runs as a Gradio HuggingFace Space.
"""

from functools import lru_cache
from pathlib import Path
from typing import Any, Dict

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import pyarrow as pa
import pyarrow.parquet as pq

# colliderml is an optional dependency: when it cannot be imported,
# HAS_COLLIDERML is False and _load_dataset can only serve tables that are
# already present in the local cache directory.
try:
    import colliderml

    HAS_COLLIDERML = True
except ImportError:
    HAS_COLLIDERML = False

# Datasets exposed in the UI. Kept short so the dropdown isn't overwhelming.
# The length of this list is also used as the size bound of the in-memory
# dataset cache.
DATASETS = [
    "ttbar_pu0",
    "ttbar_pu200",
    "higgs_portal_pu10",
    "zmumu_pu0",
    "diphoton_pu0",
    "single_muon_pu0",
]

# How many events to cache per dataset on the Space host (limits memory).
# Passed as ``max_events`` when downloading, and ``EVENTS_PER_DATASET - 1``
# is the default upper bound of the event slider.
EVENTS_PER_DATASET = 50

# Local cache directory (populated by cache_events.py at build time).
# Each dataset is looked up as a sub-directory of ``*.parquet`` files.
CACHE_DIR = Path(__file__).parent / "_cached_events"


def _frame_to_arrow(frame: Any) -> pa.Table:
    """Normalise a Polars frame (or already-pyarrow table) into a pyarrow Table.

    ``colliderml.load`` returns Polars frames in v0.4.0+; the rest of
    this Space is pyarrow-native because we want the row-level
    ``filter`` / ``to_pandas`` conveniences.

    Args:
        frame: A ``pyarrow.Table`` (returned unchanged), an eager frame that
            exposes ``to_arrow()``, or a lazy frame that exposes ``collect()``
            and whose collected result exposes ``to_arrow()``.

    Returns:
        The data as a ``pyarrow.Table``.

    Raises:
        TypeError: If ``frame`` offers none of the conversions above.
    """
    if isinstance(frame, pa.Table):
        return frame

    # Materialise first, *then* look for ``to_arrow``. A lazy frame (e.g. a
    # Polars ``LazyFrame``) exposes ``collect()`` but not necessarily
    # ``to_arrow()``, so probing for ``to_arrow`` on the un-collected object
    # would reject it before it ever got the chance to be collected.
    materialised = frame.collect() if hasattr(frame, "collect") else frame
    if hasattr(materialised, "to_arrow"):
        return materialised.to_arrow()

    raise TypeError(f"Cannot convert {type(frame).__name__} to pyarrow.Table")


class _DatasetUnavailable(Exception):
    """Internal signal: no tables could be obtained for a dataset right now."""


@lru_cache(maxsize=len(DATASETS))
def _load_dataset_cached(name: str) -> Dict[str, pa.Table]:
    """Load the tables for ``name`` and memoise them on success.

    Failure is reported by raising ``_DatasetUnavailable`` rather than by
    returning an empty dict: ``lru_cache`` does not store a result when the
    wrapped call raises, so a transient failure is retried on the next call
    instead of being remembered for the lifetime of the process.

    Raises:
        _DatasetUnavailable: If neither the local cache nor ``colliderml``
            yielded any table.
    """
    local = CACHE_DIR / name
    if local.exists():
        tables: Dict[str, pa.Table] = {
            parquet.stem: pq.read_table(str(parquet))
            for parquet in local.glob("*.parquet")
        }
        if tables:
            return tables

    if not HAS_COLLIDERML:
        raise _DatasetUnavailable(name)

    try:
        polars_frames = colliderml.load(
            name,
            tables=["tracker_hits", "particles", "tracks"],
            max_events=EVENTS_PER_DATASET,
        )
    except Exception as exc:
        # Top-level boundary for the download: report and signal "unavailable"
        # so the UI can show a placeholder instead of crashing.
        print(f"Failed to load {name}: {exc}")
        raise _DatasetUnavailable(name) from exc

    tables = {key: _frame_to_arrow(frame) for key, frame in polars_frames.items()}
    if not tables:
        raise _DatasetUnavailable(name)
    return tables


def _load_dataset(name: str) -> Dict[str, pa.Table]:
    """Return a dict of pyarrow Tables for the given dataset, cached in memory.

    Tries the local cache first (pre-populated by cache_events.py), then
    falls back to ``colliderml.load()`` which downloads on demand.

    Args:
        name: Dataset name; used both as the sub-directory of ``CACHE_DIR``
            and as the identifier passed to ``colliderml.load``.

    Returns:
        Mapping of table name to ``pyarrow.Table``. An empty dict means the
        dataset is currently unavailable; that outcome is deliberately *not*
        cached, so a later call will try again.
    """
    try:
        return _load_dataset_cached(name)
    except _DatasetUnavailable:
        return {}


# Keep the cache-management hooks that the previously decorated function had.
_load_dataset.cache_clear = _load_dataset_cached.cache_clear
_load_dataset.cache_info = _load_dataset_cached.cache_info


def _filter_event(table, event_id):
    """Pick out one event's rows and hand them back as a pandas DataFrame.

    ``None`` is passed straight through. A table without an ``event_id``
    column is converted whole, since there is nothing to select on.
    """
    if table is None:
        return None

    if "event_id" in table.column_names:
        event_ids = np.asarray(table.column("event_id"))
        # Boolean row mask: True where the row belongs to the requested event.
        table = table.filter(event_ids == event_id)

    return table.to_pandas()


def _track_polyline(track, n_points=40):
    """Turn a track's perigee parameters into a drawable 3D polyline.

    Reads ``d0``, ``z0``, ``phi``, ``theta`` and ``qop`` from the ``track``
    mapping (missing keys fall back to defaults) and returns three arrays
    ``(x, y, z)`` of ``n_points`` samples each. The shape is a straight line
    from the point of closest approach plus a quadratic sideways bend scaled
    by ``qop``. It is a drawing aid only, not a physical propagation.
    """
    d0, z0, phi = (track.get(key, 0.0) for key in ("d0", "z0", "phi"))
    theta = track.get("theta", np.pi / 2)
    qop = track.get("qop", 1e-6)

    sin_phi, cos_phi = np.sin(phi), np.cos(phi)
    sin_theta, cos_theta = np.sin(theta), np.cos(theta)

    # Path length samples from 0 to 1000 (mm, i.e. roughly 1 m of track).
    path = np.linspace(0, 1000, n_points)

    # Straight-line part: start offset by d0 perpendicular to the direction.
    x = -d0 * sin_phi + path * sin_theta * cos_phi
    y = d0 * cos_phi + path * sin_theta * sin_phi
    z = z0 + path * cos_theta

    if abs(qop) <= 1e-9:
        # Effectively no curvature information: leave the line straight.
        return x, y, z

    # 0.3 is the original author's rough field factor, not a calibrated value.
    curvature = float(qop) * 0.3
    bend = 0.5 * curvature * (path ** 2)
    x += bend * (-sin_phi)
    y += bend * cos_phi
    return x, y, z


# Upper bounds on the number of line traces drawn per event.
_MAX_TRACKS_DRAWN = 30
_MAX_PARTICLES_DRAWN = 20
# Every truth-momentum line is rescaled to this drawn length (mm).
_TRUTH_LINE_LENGTH = 500.0
# Particles with a momentum magnitude below this are not drawn.
_MIN_TRUTH_MOMENTUM = 1e-3


def _add_hits_trace(fig, hits_df):
    """Add the tracker-hit point cloud to ``fig``; no-op when there are no hits."""
    if hits_df is None or len(hits_df) == 0:
        return

    # Colour by the first identifier column that is present, if any.
    color_col = next(
        (c for c in ("layer_id", "volume_id", "particle_id") if c in hits_df.columns),
        None,
    )
    fig.add_trace(go.Scatter3d(
        x=hits_df["x"], y=hits_df["y"], z=hits_df["z"],
        mode="markers",
        marker=dict(
            size=1.6,
            color=hits_df[color_col] if color_col else "royalblue",
            colorscale="Viridis",
            opacity=0.85,
            showscale=bool(color_col),
            colorbar=dict(title=color_col) if color_col else None,
        ),
        name=f"Tracker hits ({len(hits_df)})",
        hovertemplate="x=%{x:.1f}<br>y=%{y:.1f}<br>z=%{z:.1f}<extra></extra>",
    ))


def _add_track_traces(fig, tracks_df):
    """Add one polyline per reconstructed track (first few only) to ``fig``."""
    if tracks_df is None or len(tracks_df) == 0:
        return

    for _, track in tracks_df.head(_MAX_TRACKS_DRAWN).iterrows():
        try:
            x, y, z = _track_polyline(track.to_dict())
            fig.add_trace(go.Scatter3d(
                x=x, y=y, z=z,
                mode="lines",
                line=dict(color="crimson", width=3),
                name="track",
                showlegend=False,
                hoverinfo="skip",
            ))
        except Exception:
            # Deliberate best-effort: a track with unusable parameters is
            # skipped so the rest of the event still renders.
            continue


def _add_truth_traces(fig, particles_df):
    """Add fixed-length momentum-direction lines from the origin to ``fig``."""
    if particles_df is None or len(particles_df) == 0:
        return

    prim = particles_df
    if "primary" in prim.columns:
        # Element-wise comparison on a pandas column, not a Python truth test.
        prim = prim[prim["primary"] == True]  # noqa: E712

    for _, p in prim.head(_MAX_PARTICLES_DRAWN).iterrows():
        try:
            px, py, pz = p.get("px", 0), p.get("py", 0), p.get("pz", 0)
            pmag = (px ** 2 + py ** 2 + pz ** 2) ** 0.5
            if pmag < _MIN_TRUTH_MOMENTUM:
                continue
            # Normalise so only the direction is shown, not the magnitude.
            scale = _TRUTH_LINE_LENGTH / pmag
            fig.add_trace(go.Scatter3d(
                x=[0, px * scale], y=[0, py * scale], z=[0, pz * scale],
                mode="lines",
                line=dict(color="gold", width=2, dash="dash"),
                name=f"truth pdg={int(p.get('pdg_id', 0))}",
                showlegend=False,
                hoverinfo="name",
            ))
        except Exception:
            # Deliberate best-effort: skip particles whose values cannot be
            # drawn (e.g. non-numeric entries) rather than failing the event.
            continue


def render_event(dataset_name, event_id):
    """Build the 3D figure for the selected dataset/event.

    Args:
        dataset_name: Name of the dataset to load via ``_load_dataset``.
        event_id: Value matched against the ``event_id`` column of each table.

    Returns:
        A ``plotly.graph_objects.Figure``. When the dataset yields no tables
        the figure is empty apart from a "No data available" title; otherwise
        it contains, in order, the hit point cloud, the track polylines and
        the truth-momentum lines for the event.
    """
    tables = _load_dataset(dataset_name)
    fig = go.Figure()

    if not tables:
        fig.update_layout(
            title=f"No data available for {dataset_name}",
            height=700,
        )
        return fig

    # Tracker hits — main point cloud.
    _add_hits_trace(fig, _filter_event(tables.get("tracker_hits"), event_id))
    # Reconstructed tracks — approximate polylines.
    _add_track_traces(fig, _filter_event(tables.get("tracks"), event_id))
    # Truth particles — momentum directions drawn from the origin.
    _add_truth_traces(fig, _filter_event(tables.get("particles"), event_id))

    fig.update_layout(
        title=f"{dataset_name} — event {event_id}",
        scene=dict(
            xaxis_title="x [mm]",
            yaxis_title="y [mm]",
            zaxis_title="z [mm]",
            aspectmode="data",
            bgcolor="rgb(10, 10, 25)",
        ),
        height=720,
        margin=dict(l=0, r=0, t=40, b=0),
        paper_bgcolor="rgb(10, 10, 25)",
        font=dict(color="white"),
    )
    return fig


def event_count(dataset_name):
    """Return the largest event ID found in the dataset's tracker hits.

    Despite the name this is a maximum ID, not a count; it serves as the
    upper bound of the event slider. When the hits table is missing or has
    no ``event_id`` column the result is ``EVENTS_PER_DATASET - 1``; an
    empty ``event_id`` column gives 0.
    """
    hits = _load_dataset(dataset_name).get("tracker_hits")

    if hits is not None and "event_id" in hits.column_names:
        event_ids = np.asarray(hits.column("event_id"))
        return int(event_ids.max()) if len(event_ids) else 0

    return EVENTS_PER_DATASET - 1


def on_dataset_change(dataset_name):
    """Rebuild the event slider for a newly selected dataset.

    The slider runs from 0 to the dataset's largest event ID (never below 0)
    and is reset to event 0.
    """
    upper_bound = max(event_count(dataset_name), 0)
    return gr.Slider(
        minimum=0,
        maximum=upper_bound,
        value=0,
        step=1,
        label="Event ID",
    )


# UI layout and event wiring for the Space.
with gr.Blocks(
    title="ColliderML Event Display",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=".gradio-container {max-width: 1200px !important;}",
) as demo:
    gr.Markdown(
        """
        # ColliderML Event Display

        Interactive 3D view of single events from the
        [ColliderML datasets](https://huggingface.co/datasets/CERN/ColliderML-Release-1).

        - **Blue points**: tracker hits (coloured by detector layer)
        - **Red lines**: reconstructed tracks (helical approximation)
        - **Yellow dashes**: truth particle momenta from the primary vertex
        """
    )
    with gr.Row():
        dataset = gr.Dropdown(
            DATASETS,
            value=DATASETS[0],
            label="Dataset",
            scale=1,
        )
        event_slider = gr.Slider(
            0, EVENTS_PER_DATASET - 1, value=0, step=1,
            label="Event ID",
            scale=2,
        )
    plot = gr.Plot()

    # Shared wiring for every "draw the current selection" trigger.
    _render_kwargs = dict(
        fn=render_event,
        inputs=[dataset, event_slider],
        outputs=plot,
    )

    # Dataset switch: re-range the slider first, *then* render. Registering
    # the render as a second, independent ``dataset.change`` listener would
    # let it run alongside the slider update — drawing the new dataset with
    # the previous dataset's event ID and loading the same dataset twice at
    # once. Chaining with ``.then`` runs the render after the slider has been
    # reset, and also covers the case where the slider value stays at 0 and
    # therefore emits no change of its own.
    dataset.change(
        fn=on_dataset_change,
        inputs=dataset,
        outputs=event_slider,
    ).then(**_render_kwargs)

    # Render whenever the selected event changes.
    event_slider.change(**_render_kwargs)

    # Initial render on app load.
    demo.load(**_render_kwargs)


# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    # "0.0.0.0" listens on all network interfaces rather than only localhost;
    # 7860 is presumably the port the hosting environment expects — confirm
    # against the Space configuration.
    demo.launch(server_name="0.0.0.0", server_port=7860)