github-actions[bot] commited on
Commit
cc0720f
·
0 Parent(s):

Sync from GitHub f6dbbfb

Browse files
Files changed (47) hide show
  1. .github/workflows/sync-to-hf-space.yml +41 -0
  2. .gitignore +49 -0
  3. Dockerfile +15 -0
  4. README.md +79 -0
  5. app.py +772 -0
  6. config_files/config_hits_track_v4.yaml +146 -0
  7. scripts/evaluation.sh +26 -0
  8. scripts/train_clustering.sh +20 -0
  9. scripts/train_energy_pid.sh +24 -0
  10. src/data/config.py +218 -0
  11. src/data/fileio.py +101 -0
  12. src/data/preprocess.py +253 -0
  13. src/data/tools.py +191 -0
  14. src/dataset/dataclasses.py +126 -0
  15. src/dataset/dataset.py +287 -0
  16. src/dataset/functions_data.py +26 -0
  17. src/dataset/functions_graph.py +105 -0
  18. src/dataset/functions_particles.py +122 -0
  19. src/inference.py +735 -0
  20. src/layers/clustering.py +99 -0
  21. src/layers/inference_oc.py +251 -0
  22. src/layers/object_cond.py +609 -0
  23. src/layers/regression/loss_regression.py +59 -0
  24. src/layers/shower_dataframe.py +441 -0
  25. src/layers/shower_matching.py +127 -0
  26. src/layers/tools_for_regression.py +131 -0
  27. src/layers/utils_training.py +166 -0
  28. src/models/E_correction_module.py +43 -0
  29. src/models/Gatr_pf_e_noise.py +332 -0
  30. src/models/energy_correction_NN.py +299 -0
  31. src/models/energy_correction_charged.py +116 -0
  32. src/models/energy_correction_neutral.py +157 -0
  33. src/models/wrapper/example_mode_gatr_noise.py +21 -0
  34. src/train_lightning1.py +128 -0
  35. src/utils/callbacks.py +30 -0
  36. src/utils/import_tools.py +8 -0
  37. src/utils/inference/pandas_helpers.py +36 -0
  38. src/utils/load_pretrained_models.py +32 -0
  39. src/utils/logger_wandb.py +33 -0
  40. src/utils/parser_args.py +246 -0
  41. src/utils/pid_conversion.py +7 -0
  42. src/utils/post_clustering_features.py +82 -0
  43. src/utils/train_utils.py +281 -0
  44. tests/test_cpu_attention.py +99 -0
  45. tests/test_csv_priority.py +162 -0
  46. tests/test_energy_correction_no_matches.py +90 -0
  47. tests/test_pfo_links.py +231 -0
.github/workflows/sync-to-hf-space.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Mirror this repo to the Hugging Face Space as a single squashed commit on
# every push to main (history is discarded so the token never appears in it).
name: Sync to Hugging Face Space

on:
  push:
    branches:
      - main

permissions:
  contents: read

jobs:
  sync-to-hf:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo (no history)
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          lfs: false

      - name: Push to Hugging Face Space
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # Configure git
          git config --global user.email "github-actions[bot]@users.noreply.github.com"
          git config --global user.name "github-actions[bot]"

          # Use a credential helper to avoid embedding the token in the URL
          git config --global credential.helper store
          printf 'https://user:%s@huggingface.co\n' "$HF_TOKEN" > ~/.git-credentials

          # Create a fresh repo with a single commit (no history)
          cd $GITHUB_WORKSPACE
          rm -rf .git
          git init --initial-branch main
          git add .
          git commit -m "Sync from GitHub ${GITHUB_SHA::7}"

          # Force-push the single commit to HF Space
          git push --force https://huggingface.co/spaces/gregorkrzmanc/HitPF_demo main
.gitignore ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+ *.egg-info/
7
+ dist/
8
+ build/
9
+ .eggs/
10
+
11
+ # Jupyter
12
+ .ipynb_checkpoints/
13
+ *.ipynb
14
+
15
+ # Weights & Biases
16
+ wandb/
17
+
18
+ # Model checkpoints and outputs
19
+ *.pt
20
+ *.pth
21
+ showers_df_evaluation/
22
+
23
+ # Data files
24
+ *.root
25
+ *.h5
26
+ *.hdf5
27
+ *.pkl
28
+ *.pickle
29
+ *.npy
30
+ *.npz
31
+
32
+ # Demo files are downloaded at runtime from Hugging Face Hub
33
+ model_clustering.ckpt
34
+ model_e_pid.ckpt
35
+ test_data.parquet
36
+
37
+ # Logs
38
+ *.log
39
+ logs/
40
+
41
+ # Editors
42
+ .vscode/
43
+ .idea/
44
+ *.swp
45
+ *.swo
46
+ *~
47
+
48
+ # OS
49
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image: project-provided GATr environment (see README "Dependencies").
FROM dologarcia/gatr:v9

WORKDIR /app

# Extra packages needed on top of the base image for the live demo UI.
RUN pip install --no-cache-dir \
    densitypeakclustering \
    lightning-utilities \
    torchmetrics \
    gradio \
    plotly

COPY . .
EXPOSE 7860
# Bind Gradio to all interfaces so the container port is reachable from outside.
ENV GRADIO_SERVER_NAME="0.0.0.0"
CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: HitPF
3
+ emoji: ⚛️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # HitPF
12
+
13
+ **HitPF** is a GATr-based particle-flow reconstruction model for the CLD detector at the FCC-ee.
14
+ It performs two sequential tasks:
15
+
16
+ 1. **Clustering** — groups calorimeter hits and tracks into particle-flow objects using an object-condensation loss.
17
+ 2. **Property regression** — regresses an energy-correction factor and predicts a PID class for each reconstructed cluster using a GNN-based model.
18
+
19
+ ---
20
+
21
+ ## Dependencies
22
+
23
+ The code can be used with this container:
24
+ ```docker://dologarcia/gatr:v9```
25
+
26
+ For the live demo, gradio and plotly also need to be installed:
27
+ ```
28
+ pip install gradio plotly
29
+ ```
30
+
31
+ ---
32
+
33
+ ## Dataset
34
+
35
+ Input data is stored as `.parquet` files; each file stores 100 events. A sample of the dataset in ML-ready format is available [on Zenodo](https://zenodo.org/records/18749298). The full dataset is hosted on CERN's EOS space.
36
+
37
+
38
+ ---
39
+
40
+ ## Training
41
+
42
+ ### Step 1 — Clustering
43
+
44
+ ```bash
45
+ bash scripts/train_clustering.sh
46
+ ```
47
+
48
+
49
+ ### Step 2 — Energy correction
50
+
51
+
52
+ ```bash
53
+ bash scripts/train_energy_pid.sh
54
+ ```
55
+
56
+ ### Validation
57
+
58
+ ```bash
59
+ bash scripts/evaluation.sh
60
+ ```
61
+
62
+ ---
63
+ ### Live demo (work in progress)
64
+
65
+ ```bash
66
+ python -m app
67
+ ```
68
+
69
+ ## Citation
70
+
71
+ If you use this code, please cite:
72
+
73
+ ```bibtex
74
+ @software{hitpf2026,
75
+ title = {End-to-end event reconstruction for precision physics at future colliders code},
76
+ year = {2026},
77
+ url = {https://github.com/mgarciam/HitPF}
78
+ }
79
+ ```
app.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Gradio UI for single-event MLPF inference.
4
+
5
+ Launch with:
6
+ python app.py [--device cpu]
7
+
8
+ The UI lets you:
9
+ 1. Load an event from a parquet file (pick file + event index), **or**
10
+ paste hit / track / particle data in CSV format.
11
+ 2. (Optionally) load pre-trained model checkpoints.
12
+ 3. Run inference → view predicted particles and the hit→cluster mapping.
13
+ """
14
+
15
+ import argparse
16
+ import os
17
+ import shutil
18
+ import traceback
19
+
20
+ import gradio as gr
21
+ import pandas as pd
22
+ import numpy as np
23
+ import plotly.graph_objects as go
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Auto-download demo files from Hugging Face Hub if they are not present
28
+ # ---------------------------------------------------------------------------
29
+
30
+ _HF_REPO_ID = "gregorkrzmanc/hitpf_demo_files"
31
+ _DEMO_FILES = [
32
+ "model_clustering.ckpt",
33
+ "model_e_pid.ckpt",
34
+ "test_data.parquet",
35
+ ]
36
+
37
+
38
def _ensure_demo_files(dest_dir: str = ".") -> None:
    """Fetch any demo file missing from *dest_dir* from the Hugging Face Hub.

    Best-effort: a failed download is reported on stdout and skipped, so the
    app can still start (e.g. when offline) with whatever files are present.
    """
    for fname in _DEMO_FILES:
        dest = os.path.join(dest_dir, fname)
        if os.path.isfile(dest):
            continue  # already present — nothing to do
        try:
            print(f"Downloading {fname} from HF Hub ({_HF_REPO_ID}) …")
            # hf_hub_download returns a path inside the HF cache; copy the
            # file next to the app so the rest of the code can use plain paths.
            downloaded = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename=fname,
                repo_type="dataset",
            )
            shutil.copy(downloaded, dest)
            print(f" → saved to {dest}")
        except Exception as exc:
            print(f" ⚠️ Could not download {fname}: {exc}")
54
+
55
+
56
+ _ensure_demo_files()
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Global state – filled lazily
60
+ # ---------------------------------------------------------------------------
61
+ _MODEL = None
62
+ _ARGS = None
63
+ _DEVICE = "cpu"
64
+
65
+
66
+ def _set_device(device: str):
67
+ global _DEVICE
68
+ _DEVICE = device
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Model loading
73
+ # ---------------------------------------------------------------------------
74
+
75
def load_model_ui(clustering_ckpt: str, energy_pid_ckpt: str, device: str):
    """Load model from checkpoint paths (called by the UI button).

    Returns a Markdown status string: a warning when the clustering checkpoint
    path is invalid, a success message otherwise, or the full traceback on
    failure. Sets the module globals ``_MODEL``, ``_ARGS`` and ``_DEVICE``.
    """
    global _MODEL, _ARGS, _DEVICE
    _DEVICE = device or "cpu"

    # The clustering checkpoint is mandatory; bail out early if it's missing.
    if not clustering_ckpt or not os.path.isfile(clustering_ckpt):
        return "⚠️ Please provide a valid path to the clustering checkpoint."

    # The energy/PID checkpoint is optional — silently ignored when absent.
    energy_pid = None
    if energy_pid_ckpt and os.path.isfile(energy_pid_ckpt):
        energy_pid = energy_pid_ckpt

    try:
        from src.inference import load_model
        _MODEL, _ARGS = load_model(
            clustering_ckpt=clustering_ckpt,
            energy_pid_ckpt=energy_pid,
            device=_DEVICE,
        )
    except Exception:
        return f"❌ Failed to load model:\n```\n{traceback.format_exc()}\n```"

    suffix = (
        " (clustering + energy/PID correction)"
        if energy_pid
        else " (clustering only — no energy/PID correction)"
    )
    return f"✅ Model loaded on **{_DEVICE}**" + suffix
100
+
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Event loading helpers
104
+ # ---------------------------------------------------------------------------
105
+
106
+ def _count_events_in_parquet(parquet_path: str) -> str:
107
+ """Return a short info string about the parquet file."""
108
+ if not parquet_path or not os.path.isfile(parquet_path):
109
+ return "No file selected"
110
+ try:
111
+ from src.inference import load_event_from_parquet
112
+ from src.data.fileio import _read_parquet
113
+ table = _read_parquet(parquet_path)
114
+ n = len(table["X_track"])
115
+ return f"File has **{n}** events (indices 0–{n-1})"
116
+ except Exception as e:
117
+ return f"Error reading file: {e}"
118
+
119
+
120
def _load_event_into_csv(parquet_path: str, event_index: int):
    """Load an event from a parquet file and return CSV strings for the text fields.

    Returns a 6-tuple ``(hits_csv, tracks_csv, particles_csv, pandora_csv,
    pfo_links_csv, status_markdown)``. On any failure all five CSV strings are
    empty and the status carries the error message.
    """
    if not parquet_path or not os.path.isfile(parquet_path):
        return "", "", "", "", "", "⚠️ Please provide a valid parquet file path."
    try:
        from src.inference import load_event_from_parquet
        event = load_event_from_parquet(parquet_path, int(event_index))

        hits_arr = np.asarray(event.get("X_hit", []))
        tracks_arr = np.asarray(event.get("X_track", []))
        particles_arr = np.asarray(event.get("X_gen", []))
        pandora_arr = np.asarray(event.get("X_pandora", []))

        # 2-D array → one comma-separated row per entity; anything else → "".
        def _arr_to_csv(arr):
            if arr.ndim != 2:
                return ""
            return "\n".join(",".join(str(v) for v in row) for row in arr)

        # 1-D array → single comma-separated line of ints; empty → "".
        def _1d_to_csv(arr):
            if len(arr) == 0:
                return ""
            return ",".join(str(int(v)) for v in arr)

        pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
        pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64)
        calohit_csv = _1d_to_csv(pfo_calohit)
        track_csv = _1d_to_csv(pfo_track)
        # PFO-links text field: line 1 = per-calohit PFO ids, line 2 = per-track
        # ids. A leading "\n" keeps track ids on line 2 when there are no
        # calohit ids.
        if calohit_csv and track_csv:
            pfo_links_csv = calohit_csv + "\n" + track_csv
        elif calohit_csv:
            pfo_links_csv = calohit_csv
        elif track_csv:
            pfo_links_csv = "\n" + track_csv
        else:
            pfo_links_csv = ""

        return (
            _arr_to_csv(hits_arr),
            _arr_to_csv(tracks_arr),
            _arr_to_csv(particles_arr),
            _arr_to_csv(pandora_arr),
            pfo_links_csv,
            f"✅ Loaded event **{int(event_index)}**: "
            f"{hits_arr.shape[0] if hits_arr.ndim == 2 else 0} hits, "
            f"{tracks_arr.shape[0] if tracks_arr.ndim == 2 else 0} tracks, "
            f"{particles_arr.shape[0] if particles_arr.ndim == 2 else 0} MC particles, "
            f"{pandora_arr.shape[0] if pandora_arr.ndim == 2 else 0} Pandora PFOs",
        )
    except Exception as e:
        return "", "", "", "", "", f"❌ Error loading event: {e}"
170
+
171
+
172
def _build_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by cluster ID.

    Expects columns ``x, y, z, hit_energy, hit_type, hit_index, cluster_id``.
    One Scatter3d trace is added per cluster so the legend can toggle them;
    marker size encodes the min-max-normalized hit energy (range 3–15 px).
    Returns a placeholder figure when no (valid) hit data is available.
    """
    if hit_cluster_df.empty:
        fig = go.Figure()
        fig.update_layout(title="No hit data available", height=600)
        return fig

    df = hit_cluster_df.copy()

    # Coerce coordinates/energy to numeric and drop rows with NaN/Inf values —
    # plotly would otherwise silently misrender or drop whole traces.
    for col in ("x", "y", "z", "hit_energy"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["x", "y", "z", "hit_energy"])
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid hit data (all NaN/Inf)", height=600)
        return fig

    # Normalize hit energies for marker sizes
    energies = df["hit_energy"].values.astype(float)
    e_min, e_max = float(energies.min()), float(energies.max())
    if e_max > e_min:
        norm_e = (energies - e_min) / (e_max - e_min)
    else:
        norm_e = np.ones_like(energies) * 0.5  # midpoint when all equal
    marker_sizes = 3 + norm_e * 12  # min size 3, max size 15

    # Build per-hit hover text (avoids mixed-type customdata serialization issues)
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + df["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + df["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + df["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + df["z"].map(lambda v: f"{v:.2f}")
    )

    cluster_ids = df["cluster_id"].values
    unique_clusters = sorted(set(int(c) for c in cluster_ids))

    fig = go.Figure()
    # One trace per cluster; cluster id 0 is rendered as the "noise" class.
    for cid in unique_clusters:
        mask = cluster_ids == cid
        subset = df[mask]
        sizes = marker_sizes[mask].tolist()
        label = "noise" if cid == 0 else f"cluster {cid}"
        fig.add_trace(go.Scatter3d(
            x=subset["x"].tolist(),
            y=subset["y"].tolist(),
            z=subset["z"].tolist(),
            mode="markers",
            name=label,
            marker=dict(size=sizes, opacity=0.8),
            hovertext=subset["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Hit → Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
237
+
238
+
239
def _build_pandora_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by Pandora cluster ID.

    Same layout as :func:`_build_cluster_plot` but grouped by the
    ``pandora_cluster_id`` column; id ``-1`` is shown as "unassigned".
    Returns a placeholder figure when the column is absent or no valid rows remain.
    """
    if hit_cluster_df.empty or "pandora_cluster_id" not in hit_cluster_df.columns:
        fig = go.Figure()
        fig.update_layout(title="No Pandora cluster data available", height=600)
        return fig

    df = hit_cluster_df.copy()

    # Coerce coordinates/energy to numeric and drop NaN/Inf rows. Note: hits
    # with pandora_cluster_id == -1 are NOT filtered out — they get their own
    # "unassigned" trace below.
    for col in ("x", "y", "z", "hit_energy"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["x", "y", "z", "hit_energy"])
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid hit data for Pandora plot (all NaN/Inf)", height=600)
        return fig

    # Normalize hit energies for marker sizes (3–15 px)
    energies = df["hit_energy"].values.astype(float)
    e_min, e_max = float(energies.min()), float(energies.max())
    if e_max > e_min:
        norm_e = (energies - e_min) / (e_max - e_min)
    else:
        norm_e = np.ones_like(energies) * 0.5
    marker_sizes = 3 + norm_e * 12

    # Build per-hit hover text
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Pandora cluster: " + df["pandora_cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + df["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + df["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + df["z"].map(lambda v: f"{v:.2f}")
    )

    pandora_ids = df["pandora_cluster_id"].values
    unique_clusters = sorted(set(int(c) for c in pandora_ids))

    fig = go.Figure()
    # One trace per Pandora PFO; -1 means the hit was not assigned by Pandora.
    for cid in unique_clusters:
        mask = pandora_ids == cid
        subset = df[mask]
        sizes = marker_sizes[mask].tolist()
        label = "unassigned" if cid == -1 else f"PFO {cid}"
        fig.add_trace(go.Scatter3d(
            x=subset["x"].tolist(),
            y=subset["y"].tolist(),
            z=subset["z"].tolist(),
            mode="markers",
            name=label,
            marker=dict(size=sizes, opacity=0.8),
            hovertext=subset["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Hit → Pandora Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Pandora PFO",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
304
+
305
+
306
def _build_clustering_space_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits in the learned clustering space.

    Uses the model-regressed coordinates ``cluster_x/cluster_y/cluster_z``
    instead of detector positions; grouping/colors follow ``cluster_id`` as in
    :func:`_build_cluster_plot`. Returns a placeholder figure when the
    clustering-space columns are absent or no valid rows remain.
    """
    if hit_cluster_df.empty or "cluster_x" not in hit_cluster_df.columns:
        fig = go.Figure()
        fig.update_layout(title="No clustering-space data available", height=600)
        return fig

    df = hit_cluster_df.copy()

    # Coerce to numeric and drop rows with NaN/Inf clustering-space coordinates.
    for col in ("cluster_x", "cluster_y", "cluster_z", "hit_energy"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(
        subset=["cluster_x", "cluster_y", "cluster_z", "hit_energy"]
    )
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid clustering-space data (all NaN/Inf)", height=600)
        return fig

    # Normalize hit energies for marker sizes (3–15 px)
    energies = df["hit_energy"].values.astype(float)
    e_min, e_max = float(energies.min()), float(energies.max())
    if e_max > e_min:
        norm_e = (energies - e_min) / (e_max - e_min)
    else:
        norm_e = np.ones_like(energies) * 0.5
    marker_sizes = 3 + norm_e * 12

    # Build per-hit hover text
    df["_hover"] = (
        "<b>" + df["hit_type"].astype(str) + "</b> hit #" + df["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + df["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "cluster_x: " + df["cluster_x"].map(lambda v: f"{v:.4f}")
        + ", cluster_y: " + df["cluster_y"].map(lambda v: f"{v:.4f}")
        + ", cluster_z: " + df["cluster_z"].map(lambda v: f"{v:.4f}")
    )

    cluster_ids = df["cluster_id"].values
    unique_clusters = sorted(set(int(c) for c in cluster_ids))

    fig = go.Figure()
    # One trace per cluster; cluster id 0 is rendered as the "noise" class.
    for cid in unique_clusters:
        mask = cluster_ids == cid
        subset = df[mask]
        sizes = marker_sizes[mask].tolist()
        label = "noise" if cid == 0 else f"cluster {cid}"
        fig.add_trace(go.Scatter3d(
            x=subset["cluster_x"].tolist(),
            y=subset["cluster_y"].tolist(),
            z=subset["cluster_z"].tolist(),
            mode="markers",
            name=label,
            marker=dict(size=sizes, opacity=0.8),
            hovertext=subset["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Clustering Space 3D Map (GATr regressed coordinates)",
        scene=dict(
            xaxis_title="cluster_x",
            yaxis_title="cluster_y",
            zaxis_title="cluster_z",
        ),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
377
+
378
+
379
+ # ---------------------------------------------------------------------------
380
+ # Main inference entry point for the UI
381
+ # ---------------------------------------------------------------------------
382
+
383
+ def _compute_inv_mass(df, e_col, px_col, py_col, pz_col):
384
+ """Compute the invariant mass of a system of particles in GeV.
385
+
386
+ Returns the scalar invariant mass m = sqrt(max((ΣE)²−(Σpx)²−(Σpy)²−(Σpz)², 0)),
387
+ or *None* when *df* is empty or the required columns are absent.
388
+ """
389
+ if df.empty:
390
+ return None
391
+ for col in (e_col, px_col, py_col, pz_col):
392
+ if col not in df.columns:
393
+ return None
394
+ E = float(df[e_col].sum())
395
+ px = float(df[px_col].sum())
396
+ py = float(df[py_col].sum())
397
+ pz = float(df[pz_col].sum())
398
+ m2 = E ** 2 - px ** 2 - py ** 2 - pz ** 2
399
+ return float(np.sqrt(max(m2, 0.0)))
400
+
401
+
402
+ def _fmt_mass(val):
403
+ """Format an invariant-mass value (float or None) as a GeV string."""
404
+ return f"{val:.4f} GeV" if val is not None else "N/A"
405
+
406
+
407
def run_inference_ui(
    parquet_path: str,
    event_index: int,
    csv_hits: str,
    csv_tracks: str,
    csv_particles: str,
    csv_pandora: str,
    csv_pfo_links: str = "",
):
    """Run inference on a single event, return predicted particles, 3D plots, MC particles and Pandora particles.

    Pasted CSV hits take precedence over the parquet path when both are given.
    Every return path produces the same 7-tuple shape so each Gradio output
    component always receives something renderable; errors are surfaced inside
    the first DataFrame rather than raised.

    Returns
    -------
    particles_df : pandas.DataFrame
    cluster_fig : plotly.graph_objects.Figure
    clustering_space_fig : plotly.graph_objects.Figure
    pandora_cluster_fig : plotly.graph_objects.Figure
    mc_particles_df : pandas.DataFrame
    pandora_particles_df : pandas.DataFrame
    inv_mass_summary : str
    """
    global _MODEL, _ARGS, _DEVICE

    empty_fig = go.Figure()

    if _MODEL is None:
        return (
            pd.DataFrame({"error": ["Model not loaded. Please load a model first."]}),
            empty_fig,
            empty_fig,
            empty_fig,
            pd.DataFrame(),
            pd.DataFrame(),
            "",
        )

    try:
        from src.inference import load_event_from_parquet, run_single_event_inference

        # Decide input source: CSV text wins over the parquet path.
        use_parquet = parquet_path and os.path.isfile(parquet_path)
        use_csv = bool(csv_hits and csv_hits.strip())

        if not use_parquet and not use_csv:
            return (
                pd.DataFrame({"error": ["Provide a parquet file or paste CSV hit data."]}),
                empty_fig,
                empty_fig,
                empty_fig,
                pd.DataFrame(),
                pd.DataFrame(),
                "",
            )

        if use_csv:
            event = _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links)
        elif use_parquet:
            event = load_event_from_parquet(parquet_path, int(event_index))

        particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df = run_single_event_inference(
            event, _MODEL, _ARGS, device=_DEVICE,
        )
        if particles_df.empty:
            # Keep the table component populated with an explanation.
            particles_df = pd.DataFrame({"info": ["Event produced no clusters (empty graph)."]})

        cluster_fig = _build_cluster_plot(hit_cluster_df)
        clustering_space_fig = _build_clustering_space_plot(hit_cluster_df)
        pandora_cluster_fig = _build_pandora_cluster_plot(hit_cluster_df)

        # Compute invariant masses [GeV] of the summed 4-vectors per algorithm.
        m_true = _compute_inv_mass(mc_particles_df, "energy", "px", "py", "pz")
        # HitPF uses corrected_energy when available, otherwise energy_sum_hits
        hitpf_e_col = "corrected_energy" if "corrected_energy" in particles_df.columns else "energy_sum_hits"
        m_reco_hitpf = _compute_inv_mass(particles_df, hitpf_e_col, "px", "py", "pz")
        m_reco_pandora = _compute_inv_mass(pandora_particles_df, "energy", "px", "py", "pz")

        # Markdown comparison table shown under the plots.
        inv_mass_summary = (
            f"**Invariant mass (sum of all particle 4-vectors)**\n\n"
            f"| Algorithm | m [GeV] |\n"
            f"|---|---|\n"
            f"| m_true (MC truth) | {_fmt_mass(m_true)} |\n"
            f"| m_reco (HitPF) | {_fmt_mass(m_reco_hitpf)} |\n"
            f"| m_reco (Pandora) | {_fmt_mass(m_reco_pandora)} |"
        )

        return particles_df, cluster_fig, clustering_space_fig, pandora_cluster_fig, mc_particles_df, pandora_particles_df, inv_mass_summary

    except Exception:
        # Show the full traceback in the particles table instead of crashing the UI.
        err = traceback.format_exc()
        return (
            pd.DataFrame({"error": [err]}),
            empty_fig,
            empty_fig,
            empty_fig,
            pd.DataFrame(),
            pd.DataFrame(),
            "",
        )
505
+
506
+
507
+ def _parse_csv_event(csv_hits: str, csv_tracks: str, csv_particles: str, csv_pandora: str = "", csv_pfo_links: str = ""):
508
+ """Parse user-provided CSV text into the dict-of-arrays format expected by
509
+ ``create_graph``.
510
+
511
+ Expected CSV columns for hits (X_hit) — 11 columns:
512
+ 0: hit_x — hit position x [mm]
513
+ 1: hit_y — hit position y [mm]
514
+ 2: hit_z — hit position z [mm]
515
+ 3: hit_px — hit momentum px [GeV] (0 for calo hits)
516
+ 4: hit_py — hit momentum py [GeV] (0 for calo hits)
517
+ 5: hit_energy — hit energy deposit [GeV]
518
+ 6: hit_x_calo — hit position x at calorimeter surface [mm] (used as 3D position by the model)
519
+ 7: hit_y_calo — hit position y at calorimeter surface [mm]
520
+ 8: hit_z_calo — hit position z at calorimeter surface [mm]
521
+ 9: (unused) — reserved column (set to 0)
522
+ 10: hit_type — hit sub-detector type: 1 = ECAL, 2 = HCAL, 3 = muon system
523
+
524
+ Expected CSV columns for tracks (X_track) — 25 columns (padded with
525
+ zeros if fewer are provided; minimum 17):
526
+ 0: elemtype — element type (always 1 for tracks)
527
+ 1–4: (unused) — reserved columns (set to 0)
528
+ 5: p — track momentum magnitude |p| [GeV]
529
+ 6: px_IP — track px at interaction point [GeV]
530
+ 7: py_IP — track py at interaction point [GeV]
531
+ 8: pz_IP — track pz at interaction point [GeV]
532
+ 9–11: (unused) — reserved columns (set to 0)
533
+ 12: ref_x_calo — track reference-point x at calorimeter [mm]
534
+ 13: ref_y_calo — track reference-point y at calorimeter [mm]
535
+ 14: ref_z_calo — track reference-point z at calorimeter [mm]
536
+ 15: chi2 — track-fit chi-squared
537
+ 16: ndf — track-fit number of degrees of freedom
538
+ 17–21: (unused) — reserved columns (set to 0)
539
+ 22: px_calo — track momentum x component at calorimeter [GeV]
540
+ 23: py_calo — track momentum y component at calorimeter [GeV]
541
+ 24: pz_calo — track momentum z component at calorimeter [GeV]
542
+
543
+ Expected CSV columns for particles / MC truth (X_gen) — 18 columns:
544
+ 0: pid — PDG particle ID (e.g. 211, 22, 11, 13)
545
+ 1: gen_status — generator status code
546
+ 2: isDecayedInCalo — 1 if decayed in calorimeter, else 0
547
+ 3: isDecayedInTracker — 1 if decayed in tracker, else 0
548
+ 4: theta — polar angle [rad]
549
+ 5: phi — azimuthal angle [rad]
550
+ 6: (unused) — reserved (set to 0)
551
+ 7: (unused) — reserved (set to 0)
552
+ 8: energy — true particle energy [GeV]
553
+ 9: (unused) — reserved (set to 0)
554
+ 10: mass — particle mass [GeV]
555
+ 11: momentum — momentum magnitude |p| [GeV]
556
+ 12: px — momentum x component [GeV]
557
+ 13: py — momentum y component [GeV]
558
+ 14: pz — momentum z component [GeV]
559
+ 15: vx — production vertex x [mm]
560
+ 16: vy — production vertex y [mm]
561
+ 17: vz — production vertex z [mm]
562
+
563
+ PFO links (csv_pfo_links) — two lines of comma-separated integers:
564
+ Line 1: pfo_calohit — one PFO index per calorimeter hit (-1 = unassigned)
565
+ Line 2: pfo_track — one PFO index per track (-1 = unassigned)
566
+ """
567
+ import io
568
+ import awkward as ak
569
+
570
+ def _read(text, min_cols=1):
571
+ if not text or not text.strip():
572
+ return np.zeros((0, min_cols), dtype=np.float64)
573
+ df = pd.read_csv(io.StringIO(text), header=None)
574
+ return df.values.astype(np.float64)
575
+
576
+ hits_arr = _read(csv_hits, 11)
577
+ tracks_arr = _read(csv_tracks, 25)
578
+ particles_arr = _read(csv_particles, 18)
579
+ pandora_arr = _read(csv_pandora, 9)
580
+
581
+ # Pad tracks to 25 columns if needed
582
+ if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
583
+ pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
584
+ tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
585
+
586
+ # Build ygen_hit / ygen_track (particle link per hit — use -1 for unknown)
587
+ ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
588
+ ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
589
+
590
+ # Parse PFO link arrays (hit → Pandora cluster mapping)
591
+ pfo_calohit = np.array([], dtype=np.int64)
592
+ pfo_track = np.array([], dtype=np.int64)
593
+ if csv_pfo_links and csv_pfo_links.strip():
594
+ lines = csv_pfo_links.strip().split("\n")
595
+ if len(lines) >= 1 and lines[0].strip():
596
+ pfo_calohit = np.array(
597
+ [int(v) for v in lines[0].strip().split(",")], dtype=np.int64
598
+ )
599
+ if len(lines) >= 2 and lines[1].strip():
600
+ pfo_track = np.array(
601
+ [int(v) for v in lines[1].strip().split(",")], dtype=np.int64
602
+ )
603
+
604
+ event = {
605
+ "X_hit": hits_arr,
606
+ "X_track": tracks_arr,
607
+ "X_gen": particles_arr,
608
+ "X_pandora": pandora_arr,
609
+ "ygen_hit": ygen_hit,
610
+ "ygen_track": ygen_track,
611
+ "pfo_calohit": pfo_calohit,
612
+ "pfo_track": pfo_track,
613
+ }
614
+ return event
615
+
616
+
617
+ # ---------------------------------------------------------------------------
618
+ # Build the Gradio interface
619
+ # ---------------------------------------------------------------------------
620
+
621
def build_app():
    """Construct and return the Gradio Blocks UI for single-event MLPF inference.

    The page has three accordion sections:
      1. checkpoint loading (wired to `load_model_ui`),
      2. event selection — either from a parquet file (`_count_events_in_parquet`,
         `_load_event_into_csv`) or from pasted CSV text boxes,
      3. results — tables and 3D plots filled by `run_inference_ui`.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    with gr.Blocks(title="HitPF — Single-event MLPF Inference") as demo:
        gr.Markdown(
            "# HitPF — Single-event MLPF Inference\n"
            "Run the GATr-based particle-flow reconstruction on a single event.\n\n"
            "**Steps:** 1) Load model checkpoints 2) Select an event 3) Run inference"
        )

        # ---- Model loading ----
        with gr.Accordion("1 · Load Model", open=True):
            with gr.Row():
                clustering_ckpt = gr.Textbox(
                    label="Clustering checkpoint (.ckpt)",
                    value="model_clustering.ckpt",
                    placeholder="/path/to/clustering.ckpt",
                )
                energy_pid_ckpt = gr.Textbox(
                    label="Energy / PID checkpoint (.ckpt) — optional",
                    value="model_e_pid.ckpt",
                    placeholder="/path/to/energy_pid.ckpt",
                )
                device_dd = gr.Dropdown(
                    choices=["cpu", "cuda:0", "cuda:1"],
                    value="cpu",
                    label="Device",
                )
            load_btn = gr.Button("Load model")
            load_status = gr.Markdown("")
            load_btn.click(
                fn=load_model_ui,
                inputs=[clustering_ckpt, energy_pid_ckpt, device_dd],
                outputs=load_status,
            )

        # ---- Event selection ----
        with gr.Accordion("2 · Select Event", open=True):
            gr.Markdown("**Option A** — from a parquet file:")
            with gr.Row():
                parquet_path = gr.Textbox(
                    label="Parquet file path",
                    value="test_data.parquet",
                    placeholder="/path/to/events.parquet",
                )
                event_idx = gr.Number(label="Event index", value=0, precision=0)
            parquet_info = gr.Markdown("")
            # Show the event count as soon as the user edits the path.
            parquet_path.change(
                fn=_count_events_in_parquet,
                inputs=parquet_path,
                outputs=parquet_info,
            )
            load_event_btn = gr.Button("Load event from parquet")
            load_event_status = gr.Markdown("")

            gr.Markdown(
                "---\n**Option B** — paste CSV data (one row per hit/track/particle, "
                "no header, comma-separated):\n"
            )

            csv_hits = gr.Textbox(
                label="Hits CSV (11 columns)",
                lines=4,
                placeholder=(
                    "Example (one ECAL hit, one HCAL hit):\n"
                    "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
                    "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
                ),
            )

            csv_tracks = gr.Textbox(
                label="Tracks CSV (25 columns; leave empty if none)",
                lines=3,
                placeholder=(
                    "Example (one track with p≈5 GeV):\n"
                    "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
                ),
            )

            csv_particles = gr.Textbox(
                label="Particles (MC truth) CSV (18 columns; optional)",
                lines=3,
                placeholder=(
                    "Example (one pion, one photon):\n"
                    "211,1,0,0,1.2,0.5,0,0,5.2,0,0.1396,5.198,3.1,2.0,3.3,0,0,0\n"
                    "22,1,0,0,0.8,2.1,0,0,1.5,0,0,1.5,0.5,-0.3,1.38,0,0,0"
                ),
            )

            csv_pandora = gr.Textbox(
                label="Pandora PFOs CSV (9 columns; optional)",
                lines=3,
                placeholder=(
                    "Columns: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum\n"
                    "Example (one charged pion PFO):\n"
                    "211,3.0,2.0,3.3,1800.0,150.0,90.0,5.2,5.198"
                ),
            )

            csv_pfo_links = gr.Textbox(
                label="Hit → Pandora Cluster links (optional; loaded from parquet)",
                lines=2,
                placeholder=(
                    "Line 1: PFO index per calo hit (comma-separated, -1 = unassigned)\n"
                    "Line 2: PFO index per track (comma-separated, -1 = unassigned)"
                ),
            )

            # Loading an event from parquet fills all five CSV boxes so the
            # user can inspect/edit the data before running inference.
            load_event_btn.click(
                fn=_load_event_into_csv,
                inputs=[parquet_path, event_idx],
                outputs=[csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links, load_event_status],
            )

        # ---- Run inference ----
        with gr.Accordion("3 · Results", open=True):
            run_btn = gr.Button("▶ Run Inference", variant="primary")
            inv_mass_output = gr.Markdown("")
            gr.Markdown("### Predicted Particles (HitPF)")
            particles_table = gr.Dataframe(label="Predicted particles")
            gr.Markdown("### MC Truth Particles")
            mc_particles_table = gr.Dataframe(label="MC truth particles (for comparison)")
            gr.Markdown("### Pandora Particles")
            pandora_particles_table = gr.Dataframe(label="Pandora PFO particles (for comparison)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Hit → HitPF Cluster 3D Map")
                    cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = HitPF cluster, size = energy)")
                with gr.Column():
                    gr.Markdown("### Hit → Pandora Cluster 3D Map")
                    pandora_cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = Pandora PFO, size = energy)")
            gr.Markdown("### Clustering Space 3D Map")
            clustering_space_plot = gr.Plot(label="Clustering space 3D scatter (GATr regressed coordinates)")

            # Inference takes both the parquet selection and the CSV boxes;
            # the handler decides which source to use.
            run_btn.click(
                fn=run_inference_ui,
                inputs=[parquet_path, event_idx, csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links],
                outputs=[particles_table, cluster_plot, clustering_space_plot, pandora_cluster_plot, mc_particles_table, pandora_particles_table, inv_mass_output],
            )

    return demo
760
+
761
+
762
+ # ---------------------------------------------------------------------------
763
+
764
if __name__ == "__main__":
    # CLI entry point: parse options, select the device, build and serve the UI.
    parser = argparse.ArgumentParser(description="HitPF Gradio UI")
    parser.add_argument("--device", default="cpu", help="Default device (cpu / cuda:0 / …)")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    opts = parser.parse_args()

    _set_device(opts.device)
    app = build_app()
    app.launch(share=opts.share)
config_files/config_hits_track_v4.yaml ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This one uses px, py, pz instead of theta, phi, to avoid possible errors
2
+
3
+ graph_config:
4
+ only_hits: false
5
+ prediction: true
6
+ muons: true
7
+ custom_model_kwargs:
8
+ # add custom model kwargs here
9
+ n_postgn_dense_blocks: 4
10
+ clust_space_norm: none
11
+
12
+
13
+
14
+ #treename:
15
+ selection:
16
+ ### use `&`, `|`, `~` for logical operations on numpy arrays
17
+ ### can use functions from `math`, `np` (numpy), and `awkward` in the expression
18
+ #(jet_tightId==1) & (jet_no<2) & (fj_pt>200) & (fj_pt<2500) & (((sample_isQCD==0) & (fj_isQCD==0)) | ((sample_isQCD==1) & (fj_isQCD==1))) & (event_no%7!=0)
19
+ #(recojet_e>=5)
20
+
21
+ test_time_selection:
22
+ ### selection to apply at test time (i.e., when running w/ --predict)
23
+ #(jet_tightId==1) & (jet_no<2) & (fj_pt>200) & (fj_pt<2500) & (((sample_isQCD==0) & (fj_isQCD==0)) | ((sample_isQCD==1) & (fj_isQCD==1))) & (event_no%7==0)
24
+ #(recojet_e<5)
25
+
26
+ new_variables:
27
+ ### [format] name: formula
28
+ ### can use functions from `math`, `np` (numpy), and `awkward` in the expression
29
+ #pfcand_mask: awkward.JaggedArray.ones_like(pfcand_etarel)
30
+ #sv_mask: awkward.JaggedArray.ones_like(sv_etarel)
31
+ #pfcand_mask: awkward.JaggedArray.ones_like(pfcand_e)
32
+ hit_mask: ak.ones_like(hit_e)
33
+ part_mask: ak.ones_like(part_p)
34
+ hit_e_nn: hit_e
35
+ part_p1: part_p
36
+ part_theta1: part_theta
37
+ part_phi1: part_phi
38
+ part_m1: part_m
39
+ part_pid1: part_pid
40
+
41
+ preprocess:
42
+ ### method: [manual, auto] - whether to use manually specified parameters for variable standardization
43
+ ### [note]: `[var]_mask` will not be transformed even if `method=auto`
44
+ method: auto
45
+ ### data_fraction: fraction of events to use when calculating the mean/scale for the standardization
46
+ data_fraction: 0.1
47
+
48
+ inputs:
49
+ pf_points:
50
+ pad_mode: wrap
51
+ length: 25000
52
+ vars:
53
+ - [hit_x, null]
54
+ - [hit_y, null]
55
+ - [hit_z, null]
56
+ - [hit_px, null]
57
+ - [hit_py, null]
58
+ - [hit_pz, null]
59
+ pf_points_pfo:
60
+ pad_mode: wrap
61
+ length: 25000
62
+ vars:
63
+ - [hit__pandora_px, null]
64
+ - [hit__pandora_py, null]
65
+ - [hit__pandora_pz, null]
66
+ - [hit__pandora_x, null]
67
+ - [hit__pandora_y, null]
68
+ - [hit__pandora_z, null]
69
+ - [pandora_pid, null]
70
+ pf_features:
71
+ pad_mode: wrap
72
+ length: 25000
73
+ vars:
74
+ ### [format 1]: var_name (no transformation)
75
+ ### [format 2]: [var_name,
76
+ ### subtract_by(optional, default=None, no transf. if preprocess.method=manual, auto transf. if preprocess.method=auto),
77
+ ### multiply_by(optional, default=1),
78
+ ### clip_min(optional, default=-5),
79
+ ### clip_max(optional, default=5),
80
+ ### pad_value(optional, default=0)]
81
+
82
+ - [hit_p, null]
83
+ - [hit_e, null]
84
+ - [part_theta , null]
85
+ - [part_phi , null]
86
+ - [part_p , null]
87
+ - [part_m, null]
88
+ - [part_pid, null]
89
+ - [part_isDecayedInCalorimeter, null]
90
+ - [part_isDecayedInTracker, null]
91
+ - [hit_pandora_cluster_energy, null]
92
+ - [hit_pandora_pfo_energy, null]
93
+ - [hit_chis, null]
94
+ - [part_px , null]
95
+ - [part_py , null]
96
+ - [part_pz , null]
97
+ - [part_vertex_x, null]
98
+ - [part_vertex_y, null]
99
+ - [part_vertex_z, null]
100
+
101
+
102
+ pf_vectors:
103
+ length: 25000
104
+ pad_mode: wrap
105
+ vars:
106
+ - [hit_type, null] #0
107
+ - [hit_e_nn, null] #1
108
+ # #labels
109
+ # - [part_p1, null] #2
110
+ # - [part_theta1, null] #3
111
+ # - [part_phi1, null] #4
112
+ # - [part_m1, null] #15
113
+ # - [part_pid1, null] #6
114
+ pf_vectoronly:
115
+ length: 25000
116
+ pad_mode: wrap
117
+ vars:
118
+ - [hit_genlink0, null] # hit link to MC
119
+ - [hit_genlink1, null] # pandora_cluster if val data otherwise 0
120
+ - [hit_genlink2, null] # pandora_index_pfo if val data otherwise 0
121
+ - [hit_genlink3, null] # hit link to daughter
122
+
123
+
124
+ pf_mask:
125
+ length: 25000
126
+ pad_mode: constant
127
+ vars:
128
+ - [hit_mask, null]
129
+ - [part_mask, null]
130
+
131
+
132
+ labels:
133
+ ### type can be `simple`, `custom`
134
+ ### [option 1] use `simple` for binary/multi-class classification, then `value` is a list of 0-1 labels
135
+ #type: simple
136
+ #value: [
137
+ # hit_ty
138
+ # ]
139
+ ### [option 2] otherwise use `custom` to define the label, then `value` is a map
140
+ # type: custom
141
+ # value:
142
+ # target_mass: np.where(fj_isQCD, fj_genjet_sdmass, fj_gen_mass)
143
+
144
+ observers:
145
+
146
+
scripts/evaluation.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python -m src.train_lightning1 \
2
+ --data-test /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/Z_uds_CLD_o2_v05_eval_v1/05/pf_tree_10100.parquet \
3
+ --data-config config_files/config_hits_track_v4.yaml \
4
+ --network-config src/models/wrapper/example_mode_gatr_noise.py \
5
+ --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/041225_arc_05/ \
6
+ --load-model-weights-clustering /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/041225_arc_05/_epoch=9_step=120000.ckpt \
7
+ --load-model-weights /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/040226_basic_ecor/_epoch=2_step=24000.ckpt \
8
+ --wandb-displayname eval_gun_drlog \
9
+ --gpus 2 \
10
+ --batch-size 20 \
11
+ --num-workers 4 \
12
+ --start-lr 1e-3 \
13
+ --num-epochs 100 \
14
+ --fetch-step 1 \
15
+ --fetch-by-files \
16
+ --log-wandb \
17
+ --wandb-projectname mlpf_debug_eval \
18
+ --wandb-entity fcc_ml \
19
+ --frac_cluster_loss 0 \
20
+ --qmin 1 \
21
+ --use-average-cc-pos 0.99 \
22
+ --correction \
23
+ --freeze-clustering \
24
+ --predict \
25
+ --name-output test_plot_hitpf2 \
26
+ --pandora
scripts/train_clustering.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python -m src.train_lightning1 \
2
+ --data-train /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/Z_uds_clustering_dataset_3/05/ \
3
+ --data-config config_files/config_hits_track_v4.yaml \
4
+ --network-config src/models/wrapper/example_mode_gatr_noise.py \
5
+ --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/test_hitpf/ \
6
+ --num-workers 4 \
7
+ --gpus 0,1 \
8
+ --batch-size 5 \
9
+ --num-epochs 100 \
10
+ --fetch-step 1 \
11
+ --log-wandb \
12
+ --wandb-displayname CLD_clustering_training \
13
+ --wandb-projectname mlpf_debug \
14
+ --wandb-entity ml4hep \
15
+ --frac_cluster_loss 0 \
16
+ --qmin 3 \
17
+ --use-average-cc-pos 0.98 \
18
+ --train-val-split 0.98 \
19
+ --fetch-by-files \
20
+ --train-batches 10
scripts/train_energy_pid.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python -m src.train_lightning1 \
2
+ --data-train /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/gun_ecort/05/ \
3
+ --data-config config_files/config_hits_track_v4.yaml \
4
+ --network-config src/models/wrapper/example_mode_gatr_noise.py \
5
+ --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/test_hitpf_ecor/ \
6
+ --wandb-displayname E_PID_05_basicecor_v1_1 \
7
+ --gpus 0 \
8
+ --batch-size 20 \
9
+ --num-workers 4 \
10
+ --start-lr 1e-3 \
11
+ --num-epochs 100 \
12
+ --fetch-step 1 \
13
+ --fetch-by-files \
14
+ --train-val-split 0.98 \
15
+ --train-batches 8000 \
16
+ --log-wandb \
17
+ --wandb-projectname mlpf_debug \
18
+ --wandb-entity ml4hep \
19
+ --frac_cluster_loss 0 \
20
+ --qmin 1 \
21
+ --use-average-cc-pos 0.99 \
22
+ --correction \
23
+ --freeze-clustering \
24
+ --use-gt-clusters
src/data/config.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import yaml
3
+ import copy
4
+
5
+ from src.logger.logger import _logger
6
+ from src.data.tools import _get_variable_names
7
+
8
+
9
+ def _as_list(x):
10
+ if x is None:
11
+ return None
12
+ elif isinstance(x, (list, tuple)):
13
+ return x
14
+ else:
15
+ return [x]
16
+
17
+
18
+ def _md5(fname):
19
+ '''https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file'''
20
+ import hashlib
21
+ hash_md5 = hashlib.md5()
22
+ with open(fname, "rb") as f:
23
+ for chunk in iter(lambda: f.read(4096), b""):
24
+ hash_md5.update(chunk)
25
+ return hash_md5.hexdigest()
26
+
27
+
28
class DataConfig(object):
    r"""Data loading configuration.

    Parses the (YAML-derived) options describing event selections, derived
    variables (``new_variables``), input groups, standardization parameters,
    observers/monitor variables, and precomputes which branches must be
    loaded, kept or dropped.
    """

    def __init__(self, print_info=True, **kwargs):
        opts = {
            'treename': None,
            'selection': None,
            'test_time_selection': None,
            'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
            'new_variables': {},
            'inputs': {},
            'labels': {},
            'observers': [],
            'monitor_variables': [],
            'weights': None,
            'graph_config': {},
            'custom_model_kwargs': {}
        }
        for k, v in kwargs.items():
            if v is not None:
                if isinstance(opts[k], dict):
                    opts[k].update(v)
                else:
                    opts[k] = v
        # only information in ``self.options'' will be persisted when exporting to YAML
        self.options = opts
        if print_info:
            _logger.debug(opts)

        self.selection = opts['selection']
        # fall back to the training-time selection when no test-time one is given
        self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
        self.var_funcs = copy.deepcopy(opts['new_variables'])
        # preprocessing config
        self.preprocess = opts['preprocess']
        self._auto_standardization = opts['preprocess']['method'].lower().startswith('auto')
        self._missing_standardization_info = False
        self.preprocess_params = opts['preprocess']['params'] if opts['preprocess']['params'] is not None else {}
        # inputs
        self.input_names = tuple(opts['inputs'].keys())
        self.input_dicts = {k: [] for k in self.input_names}
        self.input_shapes = {}
        for k, o in opts['inputs'].items():
            self.input_shapes[k] = (-1, len(o['vars']), o['length'])
            for v in o['vars']:
                v = _as_list(v)
                self.input_dicts[k].append(v[0])

                if opts['preprocess']['params'] is None:

                    def _get(idx, default):
                        # positional entries of the per-variable spec list are optional
                        try:
                            return v[idx]
                        except IndexError:
                            return default

                    params = {'length': o['length'], 'pad_mode': o.get('pad_mode', 'constant').lower(),
                              'center': _get(1, 'auto' if self._auto_standardization else None),
                              'scale': _get(2, 1), 'min': _get(3, -5), 'max': _get(4, 5), 'pad_value': _get(5, 0)}

                    if v[0] in self.preprocess_params and params != self.preprocess_params[v[0]]:
                        raise RuntimeError(
                            'Incompatible info for variable %s, had: \n %s\nnow got:\n %s' %
                            (v[0], str(self.preprocess_params[v[0]]), str(params)))
                    if k.endswith('_mask') and params['pad_mode'] != 'constant':
                        raise RuntimeError('The `pad_mode` must be set to `constant` for the mask input `%s`' % k)
                    if params['center'] == 'auto':
                        # standardization info must be computed from data later (AutoStandardizer)
                        self._missing_standardization_info = True
                    self.preprocess_params[v[0]] = params

        # observers
        self.observer_names = tuple(opts['observers'])
        # monitor variables
        self.monitor_variables = tuple(opts['monitor_variables'])
        # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
        self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables

        # remove self mappings (name: name) from var_funcs.
        # Iterate over a snapshot: deleting from a dict while iterating it
        # raises "dictionary changed size during iteration".
        for k, v in list(self.var_funcs.items()):
            if k == v:
                del self.var_funcs[k]

        if print_info:
            def _log(msg, *args, **kwargs):
                _logger.info(msg, *args, color='lightgray', **kwargs)
            _log('preprocess config: %s', str(self.preprocess))
            _log('selection: %s', str(self.selection))
            _log('test_time_selection: %s', str(self.test_time_selection))
            _log('var_funcs:\n - %s', '\n - '.join(str(it) for it in self.var_funcs.items()))
            _log('input_names: %s', str(self.input_names))
            _log('input_dicts:\n - %s', '\n - '.join(str(it) for it in self.input_dicts.items()))
            _log('input_shapes:\n - %s', '\n - '.join(str(it) for it in self.input_shapes.items()))
            _log('preprocess_params:\n - %s', '\n - '.join(str(it) for it in self.preprocess_params.items()))
            #_log('label_names: %s', str(self.label_names))
            _log('observer_names: %s', str(self.observer_names))
            _log('monitor_variables: %s', str(self.monitor_variables))
            if opts['weights'] is not None:
                # NOTE(review): `use_precomputed_weights`, `weight_name` etc. are
                # resolved through `__getattr__` from `options` — they must be
                # present in the `weights` config for this branch to work.
                if self.use_precomputed_weights:
                    _log('weight: %s' % self.var_funcs[self.weight_name])
                else:
                    for k in ['reweight_method', 'reweight_basewgt', 'reweight_branches', 'reweight_bins',
                              'reweight_classes', 'class_weights', 'reweight_threshold',
                              'reweight_discard_under_overflow']:
                        _log('%s: %s' % (k, getattr(self, k)))

        # parse config: figure out which branches to load / keep / drop
        self.keep_branches = set()
        aux_branches = set()
        # selection
        if self.selection:
            aux_branches.update(_get_variable_names(self.selection))
        # test time selection
        if self.test_time_selection:
            aux_branches.update(_get_variable_names(self.test_time_selection))
        # var_funcs
        self.keep_branches.update(self.var_funcs.keys())
        for expr in self.var_funcs.values():
            aux_branches.update(_get_variable_names(expr))
        # inputs
        for names in self.input_dicts.values():
            self.keep_branches.update(names)
        # labels
        #self.keep_branches.update(self.label_names)
        # weight
        #if self.weight_name:
        #    self.keep_branches.add(self.weight_name)
        #    if not self.use_precomputed_weights:
        #        aux_branches.update(self.reweight_branches)
        #        aux_branches.update(self.reweight_classes)
        # observers
        self.keep_branches.update(self.observer_names)
        # monitor variables
        self.keep_branches.update(self.monitor_variables)
        # keep and drop
        self.drop_branches = (aux_branches - self.keep_branches)
        self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys())  #- {self.weight_name, }
        if print_info:
            _logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
            _logger.debug('load_branches:\n %s', ','.join(self.load_branches))

    def __getattr__(self, name):
        """Expose YAML options as attributes.

        Raises AttributeError (not KeyError) for unknown names so that
        `hasattr`, `copy` and pickling behave correctly, and guards against
        infinite recursion when `options` itself is not yet set
        (e.g. during unpickling).
        """
        opts = self.__dict__.get('options')
        if opts is not None and name in opts:
            return opts[name]
        raise AttributeError(name)

    def dump(self, fp):
        """Persist `self.options` to a YAML file at path `fp`."""
        with open(fp, 'w') as f:
            yaml.safe_dump(self.options, f, sort_keys=False)

    @classmethod
    def load(cls, fp, load_observers=True, load_reweight_info=True, extra_selection=None, extra_test_selection=None):
        """Build a DataConfig from a YAML file, optionally AND-ing extra selections."""
        with open(fp) as f:
            options = yaml.safe_load(f)
        if not load_observers:
            options['observers'] = None
        if not load_reweight_info:
            options['weights'] = None
        if extra_selection:
            options['selection'] = '(%s) & (%s)' % (options['selection'], extra_selection)
        if extra_test_selection:
            if 'test_time_selection' not in options:
                raise RuntimeError('`test_time_selection` is not defined in the yaml file!')
            options['test_time_selection'] = '(%s) & (%s)' % (options['test_time_selection'], extra_test_selection)
        return cls(**options)

    def copy(self):
        """Return a fresh DataConfig re-built from a deep copy of the options."""
        return self.__class__(print_info=False, **copy.deepcopy(self.options))

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        return self.copy()

    def export_json(self, fp):
        """Export variable info (center/scale/clip/pad) as JSON for deployment.

        NOTE(review): uses `self.label_value`, which is not set anywhere in this
        class — it must come from `options` written by another config; verify
        before relying on this method.
        """
        import json
        j = {'output_names': self.label_value, 'input_names': self.input_names}
        for k, v in self.input_dicts.items():
            j[k] = {'var_names': v, 'var_infos': {}}
            for var_name in v:
                j[k]['var_length'] = self.preprocess_params[var_name]['length']
                info = self.preprocess_params[var_name]
                j[k]['var_infos'][var_name] = {
                    'median': 0 if info['center'] is None else info['center'],
                    'norm_factor': info['scale'],
                    'replace_inf_value': 0,
                    'lower_bound': -1e32 if info['center'] is None else info['min'],
                    'upper_bound': 1e32 if info['center'] is None else info['max'],
                    'pad': info['pad_value']
                }
        with open(fp, 'w') as f:
            json.dump(j, f, indent=2)
218
+
src/data/fileio.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import awkward as ak
3
+ import tqdm
4
+ import traceback
5
+ from src.data.tools import _concat, _concat_records
6
+
7
+
8
+
9
def _read_hdf5(filepath, branches, load_range=None):
    """Read the given `branches` from an HDF5 file into an awkward Array.

    `load_range` is a fractional (start, stop) pair of the total number of
    rows; at least one row is always kept.
    """
    import tables
    tables.set_blosc_max_threads(4)
    with tables.open_file(filepath) as f:
        columns = {name: getattr(f.root, name)[:] for name in branches}
    lo, hi = (0, 1) if load_range is None else load_range
    total = len(columns[branches[0]])
    begin = math.trunc(lo * total)
    end = max(begin + 1, math.trunc(hi * total))
    return ak.Array({name: col[begin:end] for name, col in columns.items()})
21
+
22
+
23
def _read_root(filepath, branches, load_range=None, treename=None):
    """Read `branches` from a ROOT TTree into an awkward Array.

    If `treename` is None, the file must contain exactly one TTree, which is
    auto-detected. `load_range` is a fractional (start, stop) pair of the
    entry count; at least one entry is always read.

    Raises:
        RuntimeError: if `treename` is None and the file holds several TTrees.
    """
    import uproot
    with uproot.open(filepath) as f:
        if treename is None:
            treenames = set([k.split(';')[0] for k, v in f.items() if getattr(v, 'classname', '') == 'TTree'])
            if len(treenames) == 1:
                treename = treenames.pop()
            else:
                # Report the candidate tree names (previously this wrongly
                # printed the requested `branches` instead).
                raise RuntimeError(
                    'Need to specify `treename` as more than one trees are found in file %s: %s' %
                    (filepath, str(treenames)))
        tree = f[treename]
        if load_range is not None:
            start = math.trunc(load_range[0] * tree.num_entries)
            stop = max(start + 1, math.trunc(load_range[1] * tree.num_entries))
        else:
            start, stop = None, None
        outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop)
    return outputs
42
+
43
+
44
def _read_awkd(filepath, branches, load_range=None):
    """Read `branches` from an awkward0 file, converting to awkward1 arrays."""
    import awkward0
    with awkward0.load(filepath) as f:
        columns = {name: f[name] for name in branches}
    lo, hi = (0, 1) if load_range is None else load_range
    total = len(columns[branches[0]])
    begin = math.trunc(lo * total)
    end = max(begin + 1, math.trunc(hi * total))
    converted = {name: ak.from_awkward0(col[begin:end]) for name, col in columns.items()}
    return ak.Array(converted)
55
+
56
+
57
def _slice_record(record, start, stop):
    """Return a new ak.Record whose fields are the [start:stop] slice of each field."""
    return ak.Record({field: record[field][start:stop] for field in record.fields})
62
+
63
def _read_parquet(filepath, load_range=None):
    """Read a parquet event record; optionally keep a fractional event range.

    The number of events is taken from the "X_track" field; at least one
    event is always kept when `load_range` is given.
    """
    record = ak.from_parquet(filepath)
    if load_range is not None:
        total = len(record["X_track"])
        begin = math.trunc(load_range[0] * total)
        end = max(begin + 1, math.trunc(load_range[1] * total))
        record = _slice_record(record, begin, end)
    return record
72
+
73
+
74
+ def _read_files(filelist, load_range=None, show_progressbar=False, **kwargs):
75
+ import os
76
+ table = []
77
+ if show_progressbar:
78
+ filelist = tqdm.tqdm(filelist)
79
+ for filepath in filelist:
80
+ ext = os.path.splitext(filepath)[1]
81
+ if ext not in ('.h5', '.root', '.awkd', '.parquet'):
82
+ raise RuntimeError('File %s of type `%s` is not supported!' % (filepath, ext))
83
+ a = _read_parquet(filepath, load_range=load_range)
84
+ if a is not None:
85
+ table.append(a)
86
+ table = _concat_records(table) # ak.Array
87
+ if len(table["X_track"]) == 0:
88
+ raise RuntimeError(f'Zero entries loaded when reading files {filelist} with `load_range`={load_range}.')
89
+ return table
90
+
91
+
92
def _write_root(file, table, treename='Events', compression=-1, step=1048576):
    """Write a dict of flat arrays to a ROOT file as TTree `treename`.

    Entries are written in chunks of `step`; `compression=-1` selects LZ4(4).
    """
    import uproot
    if compression == -1:
        compression = uproot.LZ4(4)
    with uproot.recreate(file, compression=compression) as fout:
        tree = fout.mktree(treename, {k: v.dtype for k, v in table.items()})
        num_entries = len(next(iter(table.values())))
        start = 0
        # loop while `start < num_entries` (not `num_entries - 1`): the old
        # condition silently dropped the final entry whenever it fell alone
        # in the last chunk (e.g. a 1-entry table wrote nothing).
        while start < num_entries:
            tree.extend({k: v[start:start + step] for k, v in table.items()})
            start += step
src/data/preprocess.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import glob
3
+ import copy
4
+ import numpy as np
5
+ import awkward as ak
6
+
7
+ from src.data.tools import _get_variable_names, _eval_expr
8
+ from src.data.fileio import _read_files
9
+
10
+
11
+ def _apply_selection(table, selection):
12
+ if selection is None:
13
+ return table
14
+ selected = ak.values_astype(_eval_expr(selection, table), 'bool')
15
+ return table[selected]
16
+
17
+
18
+ def _build_new_variables(table, funcs):
19
+ if funcs is None:
20
+ return table
21
+ for k, expr in funcs.items():
22
+ if k in table.fields:
23
+ continue
24
+ table[k] = _eval_expr(expr, table)
25
+ return table
26
+
27
+
28
+ def _clean_up(table, drop_branches):
29
+ columns = [k for k in table.fields if k not in drop_branches]
30
+ return table[columns]
31
+
32
+
33
def _build_weights(table, data_config, reweight_hists=None):
    """Compute per-event reweighting factors.

    Either returns precomputed per-event weights
    (`data_config.use_precomputed_weights`) or looks up weights from 2D
    histograms binned in the two `reweight_branches` variables, one
    histogram per reweighting class. Events belonging to no class (or
    outside the bin range when `reweight_discard_under_overflow` is set)
    keep weight 0.

    Raises:
        RuntimeError: if `data_config.weight_name` is None.
    """
    import warnings

    if data_config.weight_name is None:
        raise RuntimeError('Error when building weights: `weight_name` is None!')
    if data_config.use_precomputed_weights:
        return ak.to_numpy(table[data_config.weight_name])
    else:
        x_var, y_var = data_config.reweight_branches
        x_bins, y_bins = data_config.reweight_bins
        rwgt_sel = None
        if data_config.reweight_discard_under_overflow:
            rwgt_sel = (table[x_var] >= min(x_bins)) & (table[x_var] <= max(x_bins)) & \
                (table[y_var] >= min(y_bins)) & (table[y_var] <= max(y_bins))
        # init w/ wgt=0: events not belonging to any class in `reweight_classes` will get a weight of 0 at the end
        wgt = np.zeros(len(table), dtype='float32')
        sum_evts = 0
        if reweight_hists is None:
            reweight_hists = data_config.reweight_hists
        for label, hist in reweight_hists.items():
            pos = table[label] == 1
            if rwgt_sel is not None:
                pos = (pos & rwgt_sel)
            rwgt_x_vals = ak.to_numpy(table[x_var][pos])
            rwgt_y_vals = ak.to_numpy(table[y_var][pos])
            x_indices = np.clip(np.digitize(
                rwgt_x_vals, x_bins) - 1, a_min=0, a_max=len(x_bins) - 2)
            y_indices = np.clip(np.digitize(
                rwgt_y_vals, y_bins) - 1, a_min=0, a_max=len(y_bins) - 2)
            wgt[pos] = hist[x_indices, y_indices]
            sum_evts += np.sum(pos)
        if sum_evts != len(table):
            # The original called a bare `warn(...)` that was never imported,
            # raising NameError exactly when the warning should fire.
            warnings.warn(
                'Not all selected events used in the reweighting. '
                'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
                '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
            )
        if data_config.reweight_basewgt:
            wgt *= ak.to_numpy(table[data_config.basewgt_name])
        return wgt
71
+
72
+
73
class AutoStandardizer(object):
    r"""AutoStandardizer.
    Class to compute the variable standardization information.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(
            filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()
        # only a fraction of the events is needed to estimate median/scale
        self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

    def read_file(self, filelist):
        """Load the branches needed for auto-standardization and build the
        derived variables that require it."""
        self.keep_branches = set()
        self.load_branches = set()
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                self.keep_branches.add(k)
                if k in self._data_config.var_funcs:
                    expr = self._data_config.var_funcs[k]
                    self.load_branches.update(_get_variable_names(expr))
                else:
                    self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))

        # `_read_files` no longer takes a `branches` positional argument
        # (parquet files are read whole); the previous positional call bound
        # `load_branches` to `load_range` and raised a TypeError on the
        # duplicated `show_progressbar`. Pass everything by keyword instead.
        table = _read_files(filelist, load_range=self.load_range,
                            show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_preprocess_params(self, table):
        """Compute `center`/`scale` for every variable marked 'auto'.

        center = median; scale = 1 / max(|q84 - median|, |median - q16|)
        (robust to outliers; scale falls back to 1 when the spread is 0).
        Mask variables (`*_mask`) get no transformation. Returns a new
        params dict; the stored config is NOT mutated (the original code
        modified `self._data_config.preprocess_params` in place).
        """
        import warnings

        preprocess_params = copy.deepcopy(self._data_config.preprocess_params)
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] != 'auto':
                continue
            params = copy.deepcopy(params)
            if k.endswith('_mask'):
                params['center'] = None
            else:
                a = ak.to_numpy(ak.flatten(table[k], axis=None))
                if np.any(np.isnan(a)):
                    # previously this only slept for 10 s and silently zeroed the NaNs
                    warnings.warn('Found NaN in %s; replacing with 0 for standardization.' % k)
                    a = np.nan_to_num(a)
                low, center, high = np.percentile(a, [16, 50, 84])
                scale = max(high - center, center - low)
                scale = 1 if scale == 0 else 1. / scale
                params['center'] = float(center)
                params['scale'] = float(scale)
            preprocess_params[k] = params

        return preprocess_params

    def produce(self, output=None):
        """Run the full pipeline; optionally dump the updated config to `output`."""
        table = self.read_file(self._filelist)
        preprocess_params = self.make_preprocess_params(table)
        self._data_config.preprocess_params = preprocess_params
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['preprocess']['params'] = preprocess_params
        if output:
            self._data_config.dump(output)
        return self._data_config
143
+
144
+
145
class WeightMaker(object):
    r"""WeightMaker.
    Class to make reweighting information.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            # flatten {name: [files]} into one list
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()

    def read_file(self, filelist):
        """Load only the branches needed for reweighting, apply the selection
        and derived-variable definitions, and return the resulting table."""
        self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes +
                                 (self._data_config.basewgt_name,))
        self.load_branches = set()
        for k in self.keep_branches:
            if k in self._data_config.var_funcs:
                # derived variable: load the inputs of its expression instead
                expr = self._data_config.var_funcs[k]
                self.load_branches.update(_get_variable_names(expr))
            else:
                self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_weights(self, table):
        """Compute per-class 2D reweighting histograms.

        Supports the 'flat' method (flatten each class distribution) and the
        'ref' method (match every class to class 0), then equalizes the
        effective event yields across classes.
        """
        import warnings

        x_var, y_var = self._data_config.reweight_branches
        x_bins, y_bins = self._data_config.reweight_bins
        if not self._data_config.reweight_discard_under_overflow:
            # clip variables to be within bin ranges
            x_min, x_max = min(x_bins), max(x_bins)
            y_min, y_max = min(y_bins), max(y_bins)
            table[x_var] = np.clip(table[x_var], x_min, x_max)
            table[y_var] = np.clip(table[y_var], y_min, y_max)
        sum_evts = 0
        max_weight = 0.9
        raw_hists = {}
        class_events = {}
        result = {}
        for label in self._data_config.reweight_classes:
            pos = (table[label] == 1)
            x = ak.to_numpy(table[x_var][pos])
            y = ak.to_numpy(table[y_var][pos])
            hist, _, _ = np.histogram2d(x, y, bins=self._data_config.reweight_bins)
            sum_evts += hist.sum()
            if self._data_config.reweight_basewgt:
                # re-fill the histogram weighted by the base event weight
                w = ak.to_numpy(table[self._data_config.basewgt_name][pos])
                hist, _, _ = np.histogram2d(x, y, weights=w, bins=self._data_config.reweight_bins)

            raw_hists[label] = hist.astype('float32')
            result[label] = hist.astype('float32')
        if sum_evts != len(table):
            # was: time.sleep(10) -- make the mismatch visible instead
            warnings.warn(
                'Only %d (out of %d) events fall into the reweighting class '
                'definitions; check `reweight_classes`.' % (int(sum_evts), len(table)))

        if self._data_config.reweight_method == 'flat':
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                hist = result[label]
                threshold_ = np.median(hist[hist > 0]) * 0.01
                nonzero_vals = hist[hist > threshold_]
                ref_val = np.percentile(nonzero_vals, self._data_config.reweight_threshold)
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                wgt = np.clip(np.nan_to_num(ref_val / hist, posinf=0), 0, 1)
                result[label] = wgt
                # dividing by classwgt here will effectively increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        elif self._data_config.reweight_method == 'ref':
            # use class 0 as the reference
            hist_ref = raw_hists[self._data_config.reweight_classes[0]]
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                ratio = np.nan_to_num(hist_ref / result[label], posinf=0)
                upper = np.percentile(ratio[ratio > 0], 100 - self._data_config.reweight_threshold)
                wgt = np.clip(ratio / upper, 0, 1)  # -> [0,1]
                result[label] = wgt
                # dividing by classwgt here will effectively increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        # ``equalize'' all classes
        # multiply by max_weight (<1) to add some randomness in the sampling
        min_nevt = min(class_events.values()) * max_weight
        for label in self._data_config.reweight_classes:
            class_wgt = float(min_nevt) / class_events[label]
            result[label] *= class_wgt

        if self._data_config.reweight_basewgt:
            # rescale so that the combined (hist * base) weights stay in range
            wgts = _build_weights(table, self._data_config, reweight_hists=result)
            wgt_ref = np.percentile(wgts, 100 - self._data_config.reweight_threshold)
            for label in self._data_config.reweight_classes:
                result[label] /= wgt_ref

        return result

    def produce(self, output=None):
        """Compute the reweight histograms, store them on the data config
        (and its persistable ``options``), optionally dump to ``output``."""
        table = self.read_file(self._filelist)
        wgts = self.make_weights(table)
        self._data_config.reweight_hists = wgts
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
        if output:
            self._data_config.dump(output)
        return self._data_config
src/data/tools.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ import awkward as ak
5
+
6
def build_dummy_array(num, dtype=np.int64):
    """Return a jagged awkward Array with `num` empty sublists of `dtype`."""
    offsets = ak.index.Index64(np.zeros(num + 1, dtype=np.int64))
    content = ak.from_numpy(np.array([], dtype=dtype), highlevel=False)
    return ak.Array(ak.contents.ListOffsetArray(offsets, content))
13
+
14
def _concat_records(table):
    """Merge a list of per-file record arrays into a single ak.Record.

    Events from every input record are chained per field; any field that
    ends up completely empty is replaced with a dummy jagged float32 array so
    downstream code still sees one (empty) list per event.
    NOTE(review): the per-event Python loop inside ak.from_iter can be slow
    for large inputs.
    """
    table1 = {k : ak.from_iter([record[k][event] for record in table for event in range(len(record[k])) ]) for k in table[0].fields}
    for k in table1.keys():
        if len(ak.flatten(table1[k])) == 0:
            table1[k] = build_dummy_array(len(table1[k]), np.float32)
    table1 = ak.Record(table1)
    return table1
21
+
22
+ def _concat(arrays, axis=0):
23
+ if len(arrays) == 0:
24
+ return np.array([])
25
+ if isinstance(arrays[0], np.ndarray):
26
+ return np.concatenate(arrays, axis=axis)
27
+ else:
28
+ return ak.concatenate(arrays, axis=axis)
29
+
30
+
31
+ def _stack(arrays, axis=1):
32
+ if len(arrays) == 0:
33
+ return np.array([])
34
+ if isinstance(arrays[0], np.ndarray):
35
+ return np.stack(arrays, axis=axis)
36
+ else:
37
+ return ak.concatenate(arrays, axis=axis)
38
+
39
+
40
+ def _pad_vector(a, value=-1, dtype="float32"):
41
+ maxlen = 2000
42
+ maxlen2 = 5
43
+
44
+ x = (np.ones((len(a), maxlen, maxlen2)) * value).astype(dtype)
45
+ for idx, s in enumerate(a):
46
+ for idx_vec, s_vec in enumerate(s):
47
+ x[idx, idx_vec, : len(s_vec)] = s_vec
48
+ return x
49
+
50
+
51
+ def _pad(a, maxlen, value=0, dtype="float32"):
52
+ if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:
53
+ return a
54
+ elif isinstance(a, ak.Array):
55
+ if a.ndim == 1:
56
+ a = ak.unflatten(a, 1)
57
+ a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)
58
+ return ak.values_astype(a, dtype)
59
+ else:
60
+ x = (np.ones((len(a), maxlen)) * value).astype(dtype)
61
+ for idx, s in enumerate(a):
62
+ if not len(s):
63
+ continue
64
+ trunc = s[:maxlen].astype(dtype)
65
+ x[idx, : len(trunc)] = trunc
66
+ return x
67
+
68
+
69
def _repeat_pad(a, maxlen, shuffle=False, dtype="float32"):
    """Pad each row of `a` to `maxlen` by recycling values drawn from `a` itself.

    The flattened content of `a` is tiled (optionally shuffled) to provide
    filler values; `_pad` on a zeros-like array supplies a mask that is 1
    exactly in the padded slots, so real values are kept and only the padding
    is filled. NOTE(review): assumes the flattened array is non-empty --
    empty input would divide by zero in the tiling step.
    """
    x = ak.to_numpy(ak.flatten(a))
    x = np.tile(x, int(np.ceil(len(a) * maxlen / len(x))))
    if shuffle:
        np.random.shuffle(x)
    x = x[: len(a) * maxlen].reshape((len(a), maxlen))
    # mask: 1 where `_pad(a, maxlen)` had to pad, 0 where real values sit
    mask = _pad(ak.zeros_like(a), maxlen, value=1)
    x = _pad(a, maxlen) + mask * x
    return ak.values_astype(x, dtype)
78
+
79
+
80
+ def _clip(a, a_min, a_max):
81
+ try:
82
+ return np.clip(a, a_min, a_max)
83
+ except ValueError:
84
+ return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))
85
+
86
+
87
+ def _knn(support, query, k, n_jobs=1):
88
+ from scipy.spatial import cKDTree
89
+
90
+ kdtree = cKDTree(support)
91
+ d, idx = kdtree.query(query, k, n_jobs=n_jobs)
92
+ return idx
93
+
94
+
95
def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1):
    """Per-event kNN lookup.

    For each (support, query) pair, fills a (maxlen_q, k) index array with
    the nearest-support indices; query slots beyond the event's length keep
    the sentinel index maxlen_s - 1. Both supports and queries are truncated
    to their respective maximum lengths before the search.
    """
    assert len(supports) == len(queries)
    if maxlen_q is None:
        maxlen_q = maxlen_s
    out = np.full((len(supports), maxlen_q, k), maxlen_s - 1, dtype="int32")
    for i, (sup, qry) in enumerate(zip(supports, queries)):
        qry = qry[:maxlen_q]
        idx = _knn(sup[:maxlen_s], qry, k, n_jobs=n_jobs).reshape((-1, k))
        out[i, : len(qry), :] = idx
    return out
109
+
110
+
111
+ def _batch_permute_indices(array, maxlen):
112
+ batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1))
113
+ for i, a in enumerate(array):
114
+ batch_permute_idx[i, : len(a)] = np.random.permutation(len(a[:maxlen]))
115
+ return batch_permute_idx
116
+
117
+
118
+ def _batch_argsort(array, maxlen):
119
+ batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1))
120
+ for i, a in enumerate(array):
121
+ batch_argsort_idx[i, : len(a)] = np.argsort(a[:maxlen])
122
+ return batch_argsort_idx
123
+
124
+
125
+ def _batch_gather(array, indices):
126
+ out = array.zeros_like()
127
+ for i, (a, idx) in enumerate(zip(array, indices)):
128
+ maxlen = min(len(a), len(idx))
129
+ out[i][:maxlen] = a[idx[:maxlen]]
130
+ return out
131
+
132
+
133
def _p4_from_pxpypze(px, py, pz, energy):
    """Build Lorentz-vector records from Cartesian momentum components + energy."""
    import vector

    vector.register_awkward()
    return vector.zip({"px": px, "py": py, "pz": pz, "energy": energy})
138
+
139
+
140
def _p4_from_ptetaphie(pt, eta, phi, energy):
    """Build Lorentz-vector records from (pt, eta, phi) + energy."""
    import vector

    vector.register_awkward()
    return vector.zip({"pt": pt, "eta": eta, "phi": phi, "energy": energy})
145
+
146
+
147
def _p4_from_ptetaphim(pt, eta, phi, mass):
    """Build Lorentz-vector records from (pt, eta, phi) + mass."""
    import vector

    vector.register_awkward()
    return vector.zip({"pt": pt, "eta": eta, "phi": phi, "mass": mass})
152
+
153
+
154
+ def _get_variable_names(expr, exclude=["awkward", "ak", "np", "numpy", "math"]):
155
+ import ast
156
+
157
+ root = ast.parse(expr)
158
+ return sorted(
159
+ {
160
+ node.id
161
+ for node in ast.walk(root)
162
+ if isinstance(node, ast.Name) and not node.id.startswith("_")
163
+ }
164
+ - set(exclude)
165
+ )
166
+
167
+
168
def _eval_expr(expr, table):
    """Evaluate a data-config expression against `table`.

    The expression sees the table branches it references, the math/numpy/
    awkward modules (under their usual aliases), and the helper functions
    defined in this module.
    WARNING: this uses `eval`, so expressions from the data config are
    executed as arbitrary Python -- only load configs from trusted sources.
    """
    tmp = {k: table[k] for k in _get_variable_names(expr)}
    tmp.update(
        {
            "math": math,
            "np": np,
            "numpy": np,
            "ak": ak,
            "awkward": ak,
            "_concat": _concat,
            "_stack": _stack,
            "_pad": _pad,
            "_repeat_pad": _repeat_pad,
            "_clip": _clip,
            "_batch_knn": _batch_knn,
            "_batch_permute_indices": _batch_permute_indices,
            "_batch_argsort": _batch_argsort,
            "_batch_gather": _batch_gather,
            "_p4_from_pxpypze": _p4_from_pxpypze,
            "_p4_from_ptetaphie": _p4_from_ptetaphie,
            "_p4_from_ptetaphim": _p4_from_ptetaphim,
        }
    )
    return eval(expr, tmp)
src/dataset/dataclasses.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, List, Optional
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
@dataclass
class PandoraFeatures:
    # Per-hit features of the matched Pandora PFO; all optional and only
    # filled (in Hits.from_data) when running with the pandora flag.
    pandora_cluster: Optional[Any] = None
    pandora_cluster_energy: Optional[Any] = None
    pfo_energy: Optional[Any] = None  # PFO energy assigned to each hit
    pandora_mom: Optional[Any] = None  # |p| of the PFO linked to each hit
    pandora_ref_point: Optional[Any] = None  # PFO reference point (x, y, z)
    pandora_pid: Optional[Any] = None  # PFO particle-ID code
    pandora_pfo_link: Optional[Any] = None  # hit -> PFO index, -1 if unmatched
    pandora_mom_components: Optional[Any] = None  # PFO momentum components
18
+
19
+
20
@dataclass
class Hits:
    # Per-node (calorimeter hit + track) tensors for one event; tracks are
    # always concatenated after the hits along dim 0.
    pos_xyz_hits: Any  # (N, 3) positions (track rows use the calo reference point)
    pos_pxpypz: Any  # (N, 3) momentum at the vertex; zero for calorimeter hits
    pos_pxpypz_calo: Any  # (N, 3) momentum at the calorimeter; zero for hits
    p_hits: Any  # (N, 1) |p|; zero for calorimeter hits
    e_hits: Any  # (N, 1) energy; zero for track rows
    hit_particle_link: Any  # (N,) index of the generated particle per node
    pandora_features: Any # type PandoraFeatures (None when pandora disabled)
    hit_type_feature: Any  # (N,) integer node type (tracks are type 1)
    chi_squared_tracks: Any  # (N,) chi2/ndf for tracks, zero for hits
    hit_type_one_hot: Any  # (N, 5) one-hot encoding of hit_type_feature


    @classmethod
    def from_data(cls, output, number_hits, args, number_part):
        """Build a Hits container from one event's raw arrays.

        `output` holds per-event numpy arrays keyed by branch name; the
        column indices used below encode the parquet feature layout --
        NOTE(review): confirm them against the producer of X_hit/X_track/
        X_pandora. `number_hits`/`number_part` are accepted but unused here.
        """
        hit_particle_link_hits = torch.tensor(output["ygen_hit"])
        # track links are appended after hit links (events may have no tracks)
        if len(output["ygen_track"])>0:
            hit_particle_link_tracks= torch.tensor(output["ygen_track"])
            hit_particle_link = torch.cat((hit_particle_link_hits, hit_particle_link_tracks), dim=0)
        else:
            hit_particle_link = hit_particle_link_hits
        # hit_particle_link_calomother = torch.cat((hit_particle_link_hits_calomother, hit_particle_link_tracks), dim=0)
        if args.pandora:
            # attach Pandora PFO features per node, zeroed where no PFO matched
            pandora_features = PandoraFeatures()
            X_pandora = torch.tensor(output["X_pandora"])
            pfo_link_hits = torch.tensor(output["pfo_calohit"])
            if len(output["pfo_track"])>0:
                pfo_link_tracks = torch.tensor(output["pfo_track"])
                pfo_link = torch.cat((pfo_link_hits, pfo_link_tracks), dim=0)
            else:
                pfo_link = pfo_link_hits
            pandora_features.pandora_pfo_link = pfo_link
            # unmatched nodes (-1) temporarily point at PFO 0 for safe indexing
            pfo_link_temp = pfo_link.clone()
            pfo_link_temp[pfo_link_temp==-1]=0

            # X_pandora columns: 0 pid, 1:4 momentum components, 4:7 reference
            # point, 7 energy, 8 |p| -- presumably; TODO confirm layout
            pandora_features.pandora_mom = X_pandora[pfo_link_temp, 8]
            pandora_features.pandora_ref_point = X_pandora[pfo_link_temp, 4:7]
            pandora_features.pandora_mom_components = X_pandora[pfo_link_temp, 1:4]
            pandora_features.pandora_pid = X_pandora[pfo_link_temp, 0]
            pandora_features.pfo_energy = X_pandora[pfo_link_temp, 7]
            # zero out everything that was gathered through the fake index 0
            pandora_features.pandora_mom[pfo_link==-1]=0
            pandora_features.pandora_mom_components[pfo_link==-1]=0
            pandora_features.pandora_ref_point[pfo_link==-1]=0
            pandora_features.pandora_pid[pfo_link==-1]=0
            pandora_features.pfo_energy[pfo_link==-1]=0

        else:
            pandora_features = None
        X_hit = torch.tensor(output["X_hit"])
        if len(output["X_track"])>0:
            X_track = torch.tensor(output["X_track"])
        # obtain hit type

        hit_type_feature_hit = X_hit[:,10]+1 # hit type column shifted to 1..4
        if len(output["X_track"])>0:
            hit_type_feature_track = X_track[:,0] # elemtype (1 for tracks)
            hit_type_feature = torch.cat((hit_type_feature_hit, hit_type_feature_track), dim=0).to(torch.int64)
        else:
            hit_type_feature = hit_type_feature_hit.to(torch.int64)
        # obtain the position of the hits and the energies and p
        pos_xyz_hits_hits = X_hit[:,6:9]
        e_hits = X_hit[:,5]
        p_hits = X_hit[:,5]*0  # calorimeter hits carry no momentum

        if len(output["X_track"])>0:
            pos_xyz_hits_tracks = X_track[:,12:15] #(referencePoint_calo.i)
            pos_xyz_hits = torch.cat((pos_xyz_hits_hits, pos_xyz_hits_tracks), dim=0)
            e_tracks =X_track[:,5]*0  # track rows carry no energy deposit
            e = torch.cat((e_hits, e_tracks), dim=0).view(-1,1)
            p_tracks =X_track[:,5]
            pos_pxpypz_hits_tracks = X_track[:,6:9]
            pos_pxpypz = torch.cat((pos_xyz_hits_hits*0, pos_pxpypz_hits_tracks), dim=0)
            pos_pxpypz_hits_tracks = X_track[:,22:]
            pos_pxpypz_calo = torch.cat((pos_xyz_hits_hits*0, pos_pxpypz_hits_tracks), dim=0)
            p = torch.cat((p_hits, p_tracks), dim=0).view(-1,1)
        else:
            # no tracks in this event: momentum-like tensors stay zero
            pos_xyz_hits = pos_xyz_hits_hits
            e = e_hits.view(-1,1)
            pos_pxpypz = pos_xyz_hits_hits*0
            pos_pxpypz_calo = pos_pxpypz
            p = p_hits.view(-1,1)


        if len(output["X_track"])>0:
            # chi2 / ndf for tracks; calorimeter rows reuse the zero p_hits
            chi_tracks = X_track[:,15]/ X_track[:,16]
            chi_squared_tracks = torch.cat((p_hits, chi_tracks), dim=0)
        else:
            chi_squared_tracks = p_hits
        hit_type_one_hot = torch.nn.functional.one_hot(
            hit_type_feature, num_classes=5
        )

        return cls(
            pos_xyz_hits=pos_xyz_hits,
            pos_pxpypz=pos_pxpypz,
            pos_pxpypz_calo = pos_pxpypz_calo,
            p_hits=p,
            e_hits=e,
            hit_particle_link=hit_particle_link,
            pandora_features= pandora_features,
            hit_type_feature=hit_type_feature,
            chi_squared_tracks=chi_squared_tracks,
            hit_type_one_hot = hit_type_one_hot,
        )
125
+
126
+
src/dataset/dataset.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains a modified version of the dataloader originally from:
3
+
4
+ weaver-core
5
+ https://github.com/hqucms/weaver-core
6
+
7
+ The original implementation has been adapted and extended for the needs of this project.
8
+ Please refer to the original repository for the base implementation and license details.
9
+ Changes in this version:
10
+ - Adapted to read parquet files
11
+ - Modified batching logic to build graphs on the fly
12
+ - No reweighting or standardization of dataset
13
+ """
14
+ import os
15
+ import copy
16
+ import json
17
+ import numpy as np
18
+ import awkward as ak
19
+ import torch.utils.data
20
+ import time
21
+
22
+ from functools import partial
23
+ from concurrent.futures.thread import ThreadPoolExecutor
24
+ from src.data.tools import _pad
25
+ from src.data.fileio import _read_files
26
+ from src.data.preprocess import (
27
+ AutoStandardizer,
28
+ WeightMaker,
29
+ )
30
+ from src.dataset.functions_graph import create_graph
31
+
32
+ def _preprocess(table, options):
33
+ indices = np.arange(
34
+ len(table["X_track"])
35
+ )
36
+ if options["shuffle"]:
37
+ np.random.shuffle(indices)
38
+ return table, indices
39
+
40
+
41
def _load_next(filelist, load_range, options):
    """Read the next chunk of files and preprocess it.

    Returns (table, shuffled_indices).
    NOTE(review): `_read_files` is called here with `load_range` as the
    second positional argument, while `src/data/preprocess.py` passes a
    branch set in that slot -- confirm the parquet-reader signature in
    `src/data/fileio.py` matches this call.
    """
    table = _read_files(
        filelist, load_range,
    )
    table, indices = _preprocess(table, options)
    return table, indices
47
+
48
+
49
class _SimpleIter(object):
    r"""_SimpleIter
    Iterator object for ``SimpleIterDataset''.

    Pulls chunks of events from disk (optionally prefetching asynchronously
    in a one-thread executor), shuffles them when training, and yields one
    graph per call, skipping empty graphs.
    """

    def __init__(self, **kwargs):
        # inherit all properties from SimpleIterDataset
        self.__dict__.update(**kwargs)
        self.iter_count = 0

        # executor to read files and run preprocessing asynchronously
        self.executor = ThreadPoolExecutor(max_workers=1) if self._async_load else None

        # init: prefetch holds table and indices for the next fetch
        self.prefetch = None
        self.table = None
        self.indices = []
        self.cursor = 0

        self._seed = None
        worker_info = torch.utils.data.get_worker_info()
        file_dict = self._init_file_dict.copy()
        if worker_info is not None:
            # in a worker process: derive a per-worker name/seed and split
            # the workload by files (round-robin over the worker id)
            self._name += "_worker%d" % worker_info.id
            self._seed = worker_info.seed & 0xFFFFFFFF
            np.random.seed(self._seed)
            new_file_dict = {}
            for name, files in file_dict.items():
                new_files = files[worker_info.id :: worker_info.num_workers]
                assert len(new_files) > 0
                new_file_dict[name] = new_files
            file_dict = new_file_dict
        self.worker_file_dict = file_dict
        self.worker_filelist = sum(file_dict.values(), [])
        self.worker_info = worker_info
        self.restart()

    def restart(self):
        """(Re)shuffle the file list / load range and prefetch the first chunk."""
        print("=== Restarting DataIter %s, seed=%s ===" % (self._name, self._seed))
        # re-shuffle filelist and load range if for training
        filelist = self.worker_filelist.copy()
        if self._sampler_options["shuffle"]:
            np.random.shuffle(filelist)
        if self._file_fraction < 1:
            num_files = int(len(filelist) * self._file_fraction)
            filelist = filelist[:num_files]
        self.filelist = filelist

        if self._init_load_range_and_fraction is None:
            self.load_range = (0, 1)
        else:
            (start_pos, end_pos), load_frac = self._init_load_range_and_fraction
            interval = (end_pos - start_pos) * load_frac
            if self._sampler_options["shuffle"]:
                # randomly place the loading window inside [start_pos, end_pos]
                offset = np.random.uniform(start_pos, end_pos - interval)
                self.load_range = (offset, offset + interval)
            else:
                self.load_range = (start_pos, start_pos + interval)

        self.ipos = 0 if self._fetch_by_files else self.load_range[0]
        # prefetch the first entry asynchronously
        self._try_get_next(init=True)

    def __next__(self):
        """Return the next non-empty graph, fetching new chunks as needed."""
        graph_empty = True
        self.iter_count += 1

        while graph_empty:
            if len(self.filelist) == 0:
                raise StopIteration
            try:
                i = self.indices[self.cursor]
            except IndexError:
                # case 1: first entry, `self.indices` is still empty
                # case 2: running out of entries, `self.indices` is not empty
                while True:
                    if self.prefetch is None:
                        # reaching the end as prefetch got nothing
                        self.table = None
                        if self._async_load:
                            self.executor.shutdown(wait=False)
                        raise StopIteration
                    # get result from prefetch
                    if self._async_load:
                        self.table, self.indices = self.prefetch.result()
                    else:
                        self.table, self.indices = self.prefetch
                    # try to load the next ones asynchronously
                    self._try_get_next()
                    # check if any entries are fetched (i.e., passing selection) -- if not, do another fetch
                    if len(self.indices) > 0:
                        break
                # reset cursor
                self.cursor = 0
                i = self.indices[self.cursor]
            self.cursor += 1
            data, graph_empty = self.get_data(i)
        return data

    def _try_get_next(self, init=False):
        """Kick off loading of the next chunk, or mark the end of the data."""
        end_of_list = (
            self.ipos >= len(self.filelist)
            if self._fetch_by_files
            else self.ipos >= self.load_range[1]
        )
        if end_of_list:
            if init:
                # BUGFIX: the conditional previously bound to the whole
                # RuntimeError argument ('msg' % 0 if ... else id), so worker
                # processes raised RuntimeError(<worker id>) with no message.
                raise RuntimeError(
                    "Nothing to load for worker %d"
                    % (0 if self.worker_info is None else self.worker_info.id)
                )
            if self._infinity_mode and not self._in_memory:
                # infinity mode: re-start
                # NOTE(review): `_in_memory` is not set by SimpleIterDataset in
                # this file -- confirm it is injected via kwargs.
                self.restart()
                return
            else:
                # finite mode: set prefetch to None, exit
                self.prefetch = None
                return
        if self._fetch_by_files:
            filelist = self.filelist[int(self.ipos) : int(self.ipos + self._fetch_step)]
            load_range = self.load_range
        else:
            filelist = self.filelist
            load_range = (
                self.ipos,
                min(self.ipos + self._fetch_step, self.load_range[1]),
            )
        print('Start fetching next batch, len(filelist)=%d, load_range=%s'%(len(filelist), load_range))
        if self._async_load:
            self.prefetch = self.executor.submit(
                _load_next,
                filelist,
                load_range,
                self._sampler_options,
            )
        else:
            self.prefetch = _load_next(
                filelist, load_range, self._sampler_options
            )
        self.ipos += self._fetch_step

    def get_data(self, i):
        """Build the graph for event index `i`.

        Returns ([graph, ground_truth_particles], graph_empty).
        """
        self.args_parse.prediction = (not self.for_training)
        X = {k: self.table[k][i] for k in self.table.fields}
        [g, features_partnn], graph_empty = create_graph(
            X, self.for_training, self.args_parse
        )

        return [g, features_partnn], graph_empty
205
+
206
+
207
class SimpleIterDataset(torch.utils.data.IterableDataset):
    r"""Base IterableDataset.
    Handles dataloading.
    Arguments:
        file_dict (dict): dictionary of lists of files to be loaded.
        data_config_file (str): YAML file containing data format information.
        for_training (bool): flag indicating whether the dataset is used for training or testing.
            When set to ``True``, will enable shuffling and sampling-based reweighting.
            When set to ``False``, will disable shuffling and reweighting, but will load the observer variables.
        load_range_and_fraction (tuple of tuples, ``((start_pos, end_pos), load_frac)``): fractional range of events to load from each file.
            E.g., setting load_range_and_fraction=((0, 0.8), 0.5) will randomly load 50% out of the first 80% events from each file (so load 50%*80% = 40% of the file).
        fetch_by_files (bool): flag to control how events are retrieved each time we fetch data from disk.
            When set to ``True``, will read only a small number (set by ``fetch_step``) of files each time, but load all the events in these files.
            When set to ``False``, will read from all input files, but load only a small fraction (set by ``fetch_step``) of events each time.
            Default is ``False``, which results in a more uniform sample distribution but reduces the data loading speed.
        fetch_step (float or int): fraction of events (when ``fetch_by_files=False``) or number of files (when ``fetch_by_files=True``) to load each time we fetch data from disk.
            Event shuffling and reweighting (sampling) is performed each time after we fetch data.
            So set this to a large enough value to avoid getting an imbalanced minibatch (due to reweighting/sampling), especially when ``fetch_by_files`` set to ``True``.
            Will load all events (files) at once if set to non-positive value.
        file_fraction (float): fraction of files to load.
    """

    def __init__(
        self,
        file_dict,
        data_config_file,
        for_training=True,
        load_range_and_fraction=None,
        extra_selection=None,
        fetch_by_files=False,
        fetch_step=0.01,
        file_fraction=1,
        remake_weights=False,
        up_sample=True,
        weight_scale=1,
        max_resample=10,
        async_load=True,
        infinity_mode=False,
        name="",
        args_parse=None
    ):
        # NOTE(review): `data_config_file`, `extra_selection` and
        # `remake_weights` are accepted but not used in this adaptation.
        self._iters = {} if infinity_mode else None
        # snapshot of the attribute names present *before* the config
        # attributes are assigned; the set difference (computed at the end of
        # __init__) is exactly what gets deep-copied into each _SimpleIter
        _init_args = set(self.__dict__.keys())
        self._init_file_dict = file_dict
        self._init_load_range_and_fraction = load_range_and_fraction
        self._fetch_by_files = fetch_by_files
        self._fetch_step = fetch_step
        self._file_fraction = file_fraction
        self._async_load = async_load
        self._infinity_mode = infinity_mode
        self._name = name
        self.for_training = for_training
        self.args_parse = args_parse
        # ==== sampling parameters ====
        self._sampler_options = {
            "up_sample": up_sample,
            "weight_scale": weight_scale,
            "max_resample": max_resample,
        }

        if for_training:
            self._sampler_options.update(training=True, shuffle=True, reweight=True)
        else:
            self._sampler_options.update(training=False, shuffle=False, reweight=False)
        self._init_args = set(self.__dict__.keys()) - _init_args



    def __iter__(self):
        # finite mode: a fresh iterator per epoch;
        # infinity mode: one persistent iterator per dataloader worker
        if self._iters is None:
            kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args}
            return _SimpleIter(**kwargs)
        else:
            worker_info = torch.utils.data.get_worker_info()
            worker_id = worker_info.id if worker_info is not None else 0
            try:
                return self._iters[worker_id]
            except KeyError:
                kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args}
                self._iters[worker_id] = _SimpleIter(**kwargs)
                return self._iters[worker_id]
src/dataset/functions_data.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def calculate_distance_to_boundary(g):
    """Attach each hit's distance to the detector boundary to the graph.

    Hits with |z| beyond the endcap plane (2307) are projected onto that
    plane; all other ("barrel") hits are projected radially onto a cylinder
    of radius 2150 around the z axis. The distance is stored in
    ``g.ndata["radial_distance"]`` and exp(-d/1000) in
    ``g.ndata["radial_distance_exp"]``; the graph is returned.
    """
    barrel_radius = 2150
    endcap_z = 2307
    pos = g.ndata["pos_hits_xyz"]
    in_endcap = (torch.abs(pos[:, 2]) - endcap_z) > 0
    in_barrel = ~in_endcap
    dist = torch.ones_like(pos[:, 0])
    z_axis = torch.tensor([0, 0, 1], dtype=pos.dtype, device=pos.device)
    # project onto the barrel cylinder: scale pos by r / |z_axis x pos|
    radial_scale = barrel_radius / torch.norm(
        torch.cross(z_axis.view(1, -1), pos, dim=-1), dim=1
    ).unsqueeze(1)
    barrel_proj = radial_scale * pos
    # project onto the endcap plane at |z| = endcap_z
    endcap_proj = torch.abs(endcap_z / pos[:, 2].unsqueeze(1)) * pos
    dist[in_barrel] = torch.norm(barrel_proj - pos, dim=1)[in_barrel]
    dist[in_endcap] = torch.norm(endcap_proj[in_endcap] - pos[in_endcap], dim=1)
    g.ndata["radial_distance"] = dist
    g.ndata["radial_distance_exp"] = torch.exp(-(dist / 1000))
    return g
26
+
src/dataset/functions_graph.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import dgl
4
+ from src.dataset.functions_data import (
5
+ calculate_distance_to_boundary,
6
+ )
7
+ import time
8
+ from src.dataset.functions_particles import concatenate_Particles_GT, Particles_GT
9
+
10
+ from src.dataset.dataclasses import Hits
11
+
12
def create_inputs_from_table(
    output, prediction=False, args=None
):
    """Build the ground-truth particle container and the Hits container
    from one event's raw arrays; returns [particles, hits]."""
    n_nodes = np.int32(len(output["X_track"]) + len(output["X_hit"]))
    n_particles = np.int32(len(output["X_gen"]))

    hits = Hits.from_data(output, n_nodes, args, n_particles)

    particles = Particles_GT()
    particles.fill(output, prediction, args)

    return [particles, hits]
33
+
34
+
35
+
36
+
37
def create_graph(
    output,
    for_training =True, args=None
):
    """Build a DGL graph for one event.

    Returns ([graph, ground_truth_particles], graph_empty); graph_empty is
    True when the event should be skipped (fewer than 10 nodes, or a
    training event that contains only noise hits).
    NOTE(review): `create_inputs_from_table` always returns two elements, so
    the `len(result) == 1` branch appears unreachable -- confirm.
    """
    prediction = not for_training
    graph_empty = False

    result = create_inputs_from_table(
        output,
        prediction=prediction,
        args=args
    )

    if len(result) == 1:
        graph_empty = True
        return [0, 0], graph_empty
    else:
        (y_data_graph,hits) = result

        g = dgl.graph(([], []))
        g.add_nodes(hits.pos_xyz_hits.shape[0])
        # per-node input features: position, one-hot hit type, energy, momentum
        g.ndata["h"] = torch.cat(
            (hits.pos_xyz_hits, hits.hit_type_one_hot, hits.e_hits, hits.p_hits), dim=1
        ).float()
        g.ndata["p_hits"] = hits.p_hits.float()
        g.ndata["pos_hits_xyz"] = hits.pos_xyz_hits.float()
        g.ndata["pos_pxpypz_at_vertex"] = hits.pos_pxpypz.float()
        g.ndata["pos_pxpypz"] = hits.pos_pxpypz #TrackState::AtIP
        g.ndata["pos_pxpypz_at_calo"] = hits.pos_pxpypz_calo #TrackState::AtCalorimeter
        g = calculate_distance_to_boundary(g)
        g.ndata["hit_type"] = hits.hit_type_feature.float()
        g.ndata["e_hits"] = hits.e_hits.float()

        g.ndata["chi_squared_tracks"] = hits.chi_squared_tracks.float()
        g.ndata["particle_number"] = hits.hit_particle_link.float()+1 #(noise idx is 0 and particle MC 0 starts at 1)


        if prediction and (args.pandora):
            # attach Pandora baseline quantities for evaluation
            g.ndata["pandora_pfo"] = hits.pandora_features.pandora_pfo_link.float()
            g.ndata["pandora_pfo_energy"] = hits.pandora_features.pfo_energy.float()
            g.ndata["pandora_momentum"] = hits.pandora_features.pandora_mom_components.float()
            g.ndata["pandora_reference_point"] = hits.pandora_features.pandora_ref_point.float()
            g.ndata["pandora_pid"] = hits.pandora_features.pandora_pid.float()
        graph_empty = False
        unique_links = torch.unique(hits.hit_particle_link)
        if not prediction and unique_links.shape[0] == 1 and unique_links[0] == -1:
            # training event containing only noise hits -> skip
            graph_empty = True
        if hits.pos_xyz_hits.shape[0] < 10:
            graph_empty = True

        return [g, y_data_graph], graph_empty
88
+
89
+
90
+
91
+
92
+
93
def graph_batch_func(list_graphs):
    """collator function for graph dataloader

    Args:
        list_graphs (list): list of graphs from the iterable dataset

    Returns:
        batch dgl: dgl batch of graphs
    """
    ys = concatenate_Particles_GT(list_graphs)
    graphs_only = [sample[0] for sample in list_graphs]
    batched = dgl.batch(graphs_only)
    return batched, ys
src/dataset/functions_particles.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from sklearn.preprocessing import StandardScaler
4
+ from dataclasses import dataclass
5
+ from typing import Any, List, Optional
6
+
7
+
8
@dataclass
class Particles_GT():
    """Container for ground-truth (MC) particle properties of one event.

    Each field holds a tensor with one row per particle, or ``None`` when
    the quantity has not been filled yet.
    """

    angle: Optional[Any] = None        # (N, 2) theta/phi of the particle direction
    coord: Optional[Any] = None        # (N, 3) momentum components px, py, pz
    E: Optional[Any] = None            # (N, 1) true energy
    E_corrected: Optional[Any] = None  # (N, 1) corrected energy (initialised to E)
    m: Optional[Any] = None            # (N, 1) momentum magnitude
    mass: Optional[Any] = None         # (N, 1) invariant mass
    pid: Optional[Any] = None          # (N,) PDG particle ID
    vertex: Optional[Any] = None       # (N, 3) production vertex
    gen_status: Optional[Any] = None   # (N,) generator status code
    batch_number: Optional[Any] = None # (N, 1) event index within a batch
    endpoint: Optional[Any] = None     # optional decay endpoint, filled externally

    def fill(self, output, prediction, args):
        """Populate all fields from the raw ``X_gen`` feature matrix.

        Parameters
        ----------
        output : dict-like
            Must contain ``"X_gen"``, an (N, >=18) array whose columns hold
            pid (0), generator status (1), theta/phi (4-5), energy (8),
            mass (10), momentum (11), px/py/pz (12-14), vertex xyz (15-17).
        prediction, args
            Unused here; kept for interface compatibility with callers.
        """
        features = torch.tensor(output["X_gen"])
        self.angle = features[:, 4:6]
        self.coord = features[:, 12:15]
        self.vertex = features[:, 15:18]
        energy = features[:, 8].view(-1).unsqueeze(1)
        self.E = energy
        # No correction has been applied yet, so start E_corrected at E.
        self.E_corrected = energy
        self.m = features[:, 11].view(-1).unsqueeze(1)
        self.mass = features[:, 10].view(-1).unsqueeze(1)
        self.pid = features[:, 0]
        self.gen_status = features[:, 1]

    def __len__(self):
        """Number of particles (rows of ``E``)."""
        return len(self.E)

    def mask(self, mask):
        """Apply a boolean/index *mask* in place to every non-None field."""
        for name in self.__dict__:
            value = getattr(self, name)
            if value is None:
                continue
            # isinstance instead of type(...) == list (idiomatic type check).
            if isinstance(value, list):
                if value[0] is not None:
                    setattr(self, name, value[mask])
            else:
                setattr(self, name, value[mask])

    def copy(self):
        """Return a shallow copy sharing the underlying tensors."""
        obj = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj
63
+
64
+
65
+
66
def concatenate_Particles_GT(list_of_Particles_GT):
    """Concatenate the Particles_GT records of several events into one.

    Parameters
    ----------
    list_of_Particles_GT : list
        List of ``[graph, Particles_GT]`` pairs, one per event.

    Returns
    -------
    Particles_GT
        A single record with all per-particle tensors concatenated along
        dim 0 and a ``batch_number`` column identifying the source event.
    """
    particles = [pair[1] for pair in list_of_Particles_GT]
    first = particles[0]

    list_angle = torch.cat([p.angle for p in particles], dim=0)
    list_coord = torch.cat([p.coord for p in particles], dim=0)
    list_E = torch.cat([p.E for p in particles], dim=0)
    list_E_corr = torch.cat([p.E_corrected for p in particles], dim=0)
    list_m = torch.cat([p.m for p in particles], dim=0)
    list_mass = torch.cat([p.mass for p in particles], dim=0)
    list_pid = torch.cat([p.pid for p in particles], dim=0)
    list_genstatus = torch.cat([p.gen_status for p in particles], dim=0)

    # BUGFIX: the optional-attribute checks used to be performed on the
    # [graph, particles] *pair* (a plain list), which never carries these
    # attributes, so endpoint / decayed_in_* were always silently dropped.
    # Check the particle record itself, and require a non-None value since
    # the Particles_GT dataclass always defines ``endpoint`` (default None).
    if getattr(first, "endpoint", None) is not None:
        list_endpoint = torch.cat([p.endpoint for p in particles], dim=0)
    else:
        list_endpoint = None

    list_vertex = [p.vertex for p in particles]
    if list_vertex[0] is not None:
        list_vertex = torch.cat(list_vertex, dim=0)

    if getattr(first, "decayed_in_calo", None) is not None:
        list_dec_calo = torch.cat([p.decayed_in_calo for p in particles], dim=0)
        list_dec_track = torch.cat([p.decayed_in_tracker for p in particles], dim=0)
    else:
        list_dec_calo = None
        list_dec_track = None

    batch_number = add_batch_number(list_of_Particles_GT)

    particle_batch = Particles_GT()
    particle_batch.angle = list_angle
    particle_batch.coord = list_coord
    particle_batch.E = list_E
    particle_batch.E_corrected = list_E_corr
    particle_batch.m = list_m
    # BUGFIX: mass was concatenated but never assigned to the batch record.
    particle_batch.mass = list_mass
    particle_batch.pid = list_pid
    particle_batch.vertex = list_vertex
    particle_batch.decayed_in_calo = list_dec_calo
    particle_batch.decayed_in_tracker = list_dec_track
    particle_batch.batch_number = batch_number
    particle_batch.gen_status = list_genstatus
    particle_batch.endpoint = list_endpoint
    return particle_batch
114
+
115
def add_batch_number(list_graphs):
    """Label each ground-truth particle with the index of its source event.

    Args:
        list_graphs (list): list of ``[graph, particles]`` pairs; only the
            row count of each ``particles.E`` is used.

    Returns:
        torch.Tensor: (total_particles, 1) float tensor of event indices.
    """
    per_event = [
        torch.full((pair[1].E.shape[0], 1), float(event_idx))
        for event_idx, pair in enumerate(list_graphs)
    ]
    return torch.cat(per_event, dim=0)
src/inference.py ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standalone single-event MLPF inference.
3
+
4
+ Provides :func:`run_single_event_inference` which takes raw event data
5
+ (from a parquet file or as an awkward record) and model checkpoint paths,
6
+ runs the full particle-flow pipeline (graph construction → GATr forward
7
+ pass → density-peak clustering → energy correction & PID), and returns:
8
+
9
+ * a ``pandas.DataFrame`` of predicted particles with their properties
10
+ * a hit→cluster mapping as a ``pandas.DataFrame``
11
+ """
12
+
13
+ import argparse
14
+ import types
15
+ from typing import Optional
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+ import dgl
20
+ import awkward as ak
21
+
22
+ from src.data.fileio import _read_parquet
23
+ from src.dataset.functions_graph import create_graph
24
+ from src.dataset.functions_particles import Particles_GT, add_batch_number
25
+ from src.layers.clustering import DPC_custom_CLD, remove_bad_tracks_from_cluster
26
+ from src.utils.pid_conversion import pid_conversion_dict
27
+
28
+
29
+ # -- CPU-compatible attention patch ------------------------------------------
30
+
31
def _patch_gatr_attention_for_cpu():
    """Swap GATr's xformers attention for a pure-PyTorch fallback.

    ``xformers.ops.fmha.memory_efficient_attention`` ships no CPU kernel,
    so GATr cannot run on CPU out of the box. This monkey-patches
    ``gatr.primitives.attention.scaled_dot_product_attention`` with a naive
    softmax attention that works on any device (slower on GPU, but
    correct). Idempotent: repeated calls are no-ops.
    """
    import gatr.primitives.attention as attn_module

    if getattr(attn_module, "_cpu_patched", False):
        return

    def _naive_attention(q, k, v, attn_mask=None):
        # Shapes: (batch, heads, items, dim)
        batch, heads, n_items, dim = q.shape
        inv_sqrt_dim = float(dim) ** -0.5

        q_flat = q.reshape(batch * heads, n_items, dim)
        k_flat = k.reshape(batch * heads, n_items, dim)
        v_flat = v.reshape(batch * heads, n_items, dim)

        scores = torch.bmm(q_flat * inv_sqrt_dim, k_flat.transpose(1, 2))  # (B*H, N, N)

        if attn_mask is not None:
            allowed = _block_diag_mask_to_dense(attn_mask, n_items, q.device)
            if allowed is not None:
                scores = scores.masked_fill(~allowed.unsqueeze(0), float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        # Fully-masked rows become NaN after softmax; force them to zero.
        weights = weights.nan_to_num(0.0)

        return torch.bmm(weights, v_flat).reshape(batch, heads, n_items, dim)

    attn_module.scaled_dot_product_attention = _naive_attention
    attn_module._cpu_patched = True
70
+
71
+
72
+ def _block_diag_mask_to_dense(attn_mask, total_len, device):
73
+ """Convert an ``xformers.ops.fmha.BlockDiagonalMask`` to a dense bool mask."""
74
+ try:
75
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
76
+ if not isinstance(attn_mask, BlockDiagonalMask):
77
+ return None
78
+ except ImportError:
79
+ return None
80
+
81
+ # Extract per-sequence start offsets
82
+ try:
83
+ seqstarts = attn_mask.q_seqinfo.seqstart_py
84
+ except AttributeError:
85
+ try:
86
+ seqstarts = attn_mask.q_seqinfo.seqstart.cpu().tolist()
87
+ except Exception:
88
+ return None
89
+
90
+ mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
91
+ for i in range(len(seqstarts) - 1):
92
+ s, e = seqstarts[i], seqstarts[i + 1]
93
+ mask[s:e, s:e] = True
94
+ return mask
95
+
96
+
97
+ # -- PID label → human-readable name ----------------------------------------
98
+
99
# Internal PID class index (as predicted by the model's PID head) →
# human-readable particle-class name.
_PID_LABELS = {
    0: "electron",
    1: "charged hadron",
    2: "neutral hadron",
    3: "photon",
    4: "muon",
}

# |PDG code| → short particle name, used for the MC-truth and Pandora tables.
_ABS_PDG_NAME = {
    11: "electron",
    13: "muon",
    22: "photon",
    130: "K_L",
    211: "pion±",
    321: "kaon±",
    2112: "neutron",
    2212: "proton",
    310: "K_S",
}
118
+
119
+
120
+ # -- Minimal args namespace for inference ------------------------------------
121
+
122
+ def _default_args(**overrides):
123
+ """Return a minimal ``argparse.Namespace`` with defaults the model expects."""
124
+ d = dict(
125
+ correction=True,
126
+ freeze_clustering=True,
127
+ predict=True,
128
+ pandora=False,
129
+ use_gt_clusters=False,
130
+ use_average_cc_pos=0.99,
131
+ qmin=1.0,
132
+ data_config="config_files/config_hits_track_v4.yaml",
133
+ network_config="src/models/wrapper/example_mode_gatr_noise.py",
134
+ model_prefix="/tmp/mlpf_eval",
135
+ start_lr=1e-3,
136
+ frac_cluster_loss=0,
137
+ local_rank=0,
138
+ gpus="0",
139
+ batch_size=1,
140
+ num_workers=0,
141
+ prefetch_factor=1,
142
+ num_epochs=1,
143
+ steps_per_epoch=None,
144
+ samples_per_epoch=None,
145
+ steps_per_epoch_val=None,
146
+ samples_per_epoch_val=None,
147
+ train_val_split=0.8,
148
+ data_train=[],
149
+ data_val=[],
150
+ data_test=[],
151
+ data_fraction=1,
152
+ file_fraction=1,
153
+ fetch_by_files=True,
154
+ fetch_step=1,
155
+ log_wandb=False,
156
+ wandb_displayname="",
157
+ wandb_projectname="",
158
+ wandb_entity="",
159
+ name_output="gradio",
160
+ train_batches=100,
161
+ )
162
+ d.update(overrides)
163
+ return argparse.Namespace(**d)
164
+
165
+
166
+ # -- Model loading -----------------------------------------------------------
167
+
168
def load_model(
    clustering_ckpt: str,
    energy_pid_ckpt: Optional[str] = None,
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Load the full MLPF model (clustering + optional energy/PID correction).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (``.ckpt``).
    energy_pid_ckpt : str or None
        Path to the energy-correction / PID checkpoint (``.ckpt``).
        If *None*, only clustering is performed (no energy correction / PID).
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model in eval mode, moved to *device*.
    args : argparse.Namespace
        The arguments namespace used to build the model.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    overrides = dict(args_overrides or {})
    overrides["correction"] = energy_pid_ckpt is not None
    args = _default_args(**overrides)
    target_device = torch.device(device)

    if energy_pid_ckpt is not None:
        # Start from the energy/PID checkpoint, then graft the clustering
        # layers from the clustering checkpoint on top of it.
        checkpoint = torch.load(energy_pid_ckpt, map_location=target_device)
        model = ExampleWrapper(args=args, dev=0)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
        clustering_model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=target_device,
        )
        model.gatr = clustering_model.gatr
        model.ScaledGooeyBatchNorm2_1 = clustering_model.ScaledGooeyBatchNorm2_1
        model.clustering = clustering_model.clustering
        model.beta = clustering_model.beta
    else:
        model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=target_device,
        )

    model = model.to(target_device)
    model.eval()
    return model, args
225
+
226
+
227
def load_random_model(
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Create a GATr model with randomly initialised weights (no checkpoint).

    Useful for debugging: comparing against this model verifies that
    checkpoint weights are actually being loaded and used.

    Parameters
    ----------
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model (random weights) in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    overrides = dict(args_overrides or {})
    # Random models never carry the energy-correction branch.
    overrides["correction"] = False
    args = _default_args(**overrides)

    model = ExampleWrapper(args=args, dev=0).to(torch.device(device))
    model.eval()
    return model, args
262
+
263
+
264
+ # -- Single-event data loading -----------------------------------------------
265
+
266
def load_event_from_parquet(parquet_path: str, event_index: int = 0):
    """Read a single event from a parquet file.

    Parameters
    ----------
    parquet_path : str
        Path to the parquet file produced by preprocessing.
    event_index : int
        Zero-based index of the event to extract.

    Returns
    -------
    dict
        Awkward record with fields ``X_hit``, ``X_track``, ``X_gen``,
        ``ygen_hit``, ``ygen_track``, etc. for the selected event.

    Raises
    ------
    IndexError
        If *event_index* is beyond the number of events in the file.
    """
    table = _read_parquet(parquet_path)
    n_events = len(table["X_track"])
    if event_index >= n_events:
        raise IndexError(
            f"event_index {event_index} out of range (file has {n_events} events)"
        )
    return {name: table[name][event_index] for name in table.fields}
280
+
281
+
282
+ # -- Core inference function --------------------------------------------------
283
+
284
@torch.no_grad()
def run_single_event_inference(
    event,
    model,
    args,
    device: str = "cpu",
):
    """Run full MLPF inference on a single event.

    Pipeline: graph construction → GATr forward pass → density-peak
    clustering → (optional) energy correction & PID.

    Parameters
    ----------
    event : dict-like
        A single event record (from :func:`load_event_from_parquet`).
    model : ExampleWrapper
        The loaded model (from :func:`load_model`).
    args : argparse.Namespace
        The arguments namespace (from :func:`load_model`).
    device : str
        Device string.

    Returns
    -------
    particles_df : pandas.DataFrame
        One row per predicted particle with columns:
        ``cluster_id``, ``energy``, ``pid_class``, ``pid_label``,
        ``px``, ``py``, ``pz``, ``is_charged``.
    hit_cluster_df : pandas.DataFrame
        One row per hit with columns:
        ``hit_index``, ``cluster_id``, ``pandora_cluster_id``,
        ``hit_type_id``, ``hit_type``, ``x``, ``y``, ``z``,
        ``hit_energy``, ``cluster_x``, ``cluster_y``, ``cluster_z``.
        ``pandora_cluster_id`` is -1 when pandora data is not available
        or when the hit has no matching entry (e.g. CSV was modified after
        loading from parquet).
    mc_particles_df : pandas.DataFrame
        One row per MC truth particle with columns:
        ``pid``, ``energy``, ``momentum``, ``px``, ``py``, ``pz``,
        ``mass``, ``theta``, ``phi``, ``vx``, ``vy``, ``vz``,
        ``gen_status``, ``pdg_name``.
    pandora_particles_df : pandas.DataFrame
        One row per Pandora PFO with columns:
        ``pfo_idx``, ``pid``, ``pdg_name``, ``energy``, ``momentum``,
        ``px``, ``py``, ``pz``, ``ref_x``, ``ref_y``, ``ref_z``.
        Empty when pandora data is not available in the input.
    """
    dev = torch.device(device)

    # Ensure eval mode so that BatchNorm layers use running statistics from
    # training instead of computing batch statistics from the current
    # (single-event) input. Without this, inference with batch_size=1
    # produces incorrect normalization.
    model.eval()

    # xformers attention has no CPU kernel — swap in the naive fallback.
    if dev.type == "cpu":
        _patch_gatr_attention_for_cpu()

    # 0. Extract MC truth particles table and pandora particles
    mc_particles_df = _extract_mc_particles(event)
    pandora_particles_df, pfo_calohit, pfo_track = _extract_pandora_particles(event)

    # 1. Build DGL graph from the event
    [g, y_data], graph_empty = create_graph(event, for_training=False, args=args)
    if graph_empty:
        # Too few hits / no valid particles: return empty prediction tables
        # but still expose the truth and Pandora tables.
        return pd.DataFrame(), pd.DataFrame(), mc_particles_df, pandora_particles_df

    g = g.to(dev)
    # Prepare batch metadata expected by the model (single event → batch 0)
    y_data.batch_number = torch.zeros(y_data.E.shape[0], 1)

    # 2. Forward pass through the GATr clustering backbone
    inputs = g.ndata["pos_hits_xyz"].float().to(dev)
    inputs_scalar = g.ndata["hit_type"].float().view(-1, 1).to(dev)

    from gatr.interface import embed_point, embed_scalar
    from xformers.ops.fmha import BlockDiagonalMask

    # Embed hit positions as multivector points plus the hit type as scalar.
    inputs_normed = model.ScaledGooeyBatchNorm2_1(inputs)
    embedded_inputs = embed_point(inputs_normed) + embed_scalar(inputs_scalar)
    embedded_inputs = embedded_inputs.unsqueeze(-2)
    # Single event → one attention block covering all nodes.
    mask = BlockDiagonalMask.from_seqlens([g.num_nodes()])
    scalars = torch.cat(
        (g.ndata["e_hits"].float().to(dev), g.ndata["p_hits"].float().to(dev)), dim=1
    )

    from gatr.interface import extract_point, extract_scalar

    embedded_outputs, scalar_outputs = model.gatr(
        embedded_inputs, scalars=scalars, attention_mask=mask
    )
    points = extract_point(embedded_outputs[:, 0, :])
    nodewise_outputs = extract_scalar(embedded_outputs)
    x_point = points
    x_scalar = torch.cat(
        (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
    )
    # Learned clustering coordinates and per-hit condensation strength beta.
    x_cluster_coord = model.clustering(x_point)
    beta = model.beta(x_scalar)

    g.ndata["final_cluster"] = x_cluster_coord
    g.ndata["beta"] = beta.view(-1)

    # 3. Density-peak clustering (label 0 = noise, clusters start at 1)
    labels = DPC_custom_CLD(x_cluster_coord, g, dev)
    labels, _ = remove_bad_tracks_from_cluster(g, labels)

    # 4. Build hit→cluster table
    n_hits = g.num_nodes()
    hit_types_raw = g.ndata["hit_type"].cpu().numpy()
    hit_type_names = {1: "track", 2: "ECAL", 3: "HCAL", 4: "muon"}

    # Build pandora cluster ID per node (hits first, then tracks)
    # Use min of array lengths for graceful handling when CSV was modified
    n_calo = len(np.asarray(event.get("X_hit", [])))
    pandora_cluster_ids = np.full(n_hits, -1, dtype=np.int64)
    if len(pfo_calohit) > 0:
        n_assign = min(len(pfo_calohit), n_calo)
        pandora_cluster_ids[:n_assign] = pfo_calohit[:n_assign]
    n_tracks = n_hits - n_calo
    if n_tracks > 0 and len(pfo_track) > 0:
        n_assign = min(len(pfo_track), n_tracks)
        pandora_cluster_ids[n_calo:n_calo + n_assign] = pfo_track[:n_assign]

    hit_cluster_df = pd.DataFrame({
        "hit_index": np.arange(n_hits),
        "cluster_id": labels.cpu().numpy(),
        "pandora_cluster_id": pandora_cluster_ids,
        "hit_type_id": hit_types_raw,
        "hit_type": [hit_type_names.get(int(t), str(int(t))) for t in hit_types_raw],
        "x": g.ndata["pos_hits_xyz"][:, 0].cpu().numpy(),
        "y": g.ndata["pos_hits_xyz"][:, 1].cpu().numpy(),
        "z": g.ndata["pos_hits_xyz"][:, 2].cpu().numpy(),
        "hit_energy": g.ndata["e_hits"].view(-1).cpu().numpy(),
        "cluster_x": x_cluster_coord[:, 0].cpu().numpy(),
        "cluster_y": x_cluster_coord[:, 1].cpu().numpy(),
        "cluster_z": x_cluster_coord[:, 2].cpu().numpy(),
    })

    # 5. Per-cluster summary (basic, before energy correction)
    unique_labels = torch.unique(labels)
    # cluster 0 = noise
    cluster_ids = unique_labels[unique_labels > 0].cpu().numpy()

    from torch_scatter import scatter_add

    # Per-cluster sums of hit energy, track momentum and hit count.
    e_per_cluster = scatter_add(
        g.ndata["e_hits"].view(-1).to(dev), labels.to(dev)
    )
    p_per_cluster = scatter_add(
        g.ndata["p_hits"].view(-1).to(dev), labels.to(dev)
    )
    n_hits_per_cluster = scatter_add(
        torch.ones(n_hits, device=dev), labels.to(dev)
    )
    # Check if any cluster has a track (→ charged)
    is_track_per_cluster = scatter_add(
        (g.ndata["hit_type"].to(dev) == 1).float(), labels.to(dev)
    )

    rows = []
    for cid in cluster_ids:
        mask_c = labels == cid
        e_sum = e_per_cluster[cid].item()
        p_sum = p_per_cluster[cid].item()
        n_h = int(n_hits_per_cluster[cid].item())
        has_track = is_track_per_cluster[cid].item() >= 1
        # Mean position
        pos_mean = g.ndata["pos_hits_xyz"][mask_c].mean(dim=0).cpu().numpy()
        rows.append({
            "cluster_id": int(cid),
            "energy_sum_hits": round(e_sum, 4),
            "p_track": round(p_sum, 4) if has_track else 0.0,
            "n_hits": n_h,
            "is_charged": has_track,
            "mean_x": round(float(pos_mean[0]), 2),
            "mean_y": round(float(pos_mean[1]), 2),
            "mean_z": round(float(pos_mean[2]), 2),
        })

    particles_df = pd.DataFrame(rows)

    # 6. If energy correction is available, run it
    if args.correction and hasattr(model, "energy_correction"):
        try:
            particles_df = _run_energy_correction(
                model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev
            )
        except Exception as e:
            # Attach a note but don't crash – the basic table is still useful
            particles_df["note"] = f"Energy correction failed: {e}"

    return particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df
475
+
476
+
477
def _extract_mc_particles(event):
    """Build a DataFrame of MC truth particles from the event's ``X_gen``.

    Returns an empty DataFrame when ``X_gen`` is missing, has no rows, or
    has fewer than the 18 expected columns.
    """
    x_gen = np.asarray(event.get("X_gen", []))
    if x_gen.ndim != 2 or x_gen.shape[0] == 0 or x_gen.shape[1] < 18:
        return pd.DataFrame()

    def _row(idx, p):
        # Column layout of X_gen: pid(0), status(1), theta/phi(4-5),
        # E(8), mass(10), |p|(11), px/py/pz(12-14), vertex xyz(15-17).
        pdg = int(p[0])
        return {
            "particle_idx": idx,
            "pid": pdg,
            "pdg_name": _ABS_PDG_NAME.get(abs(pdg), str(pdg)),
            "gen_status": int(p[1]),
            "energy": round(float(p[8]), 4),
            "momentum": round(float(p[11]), 4),
            "px": round(float(p[12]), 4),
            "py": round(float(p[13]), 4),
            "pz": round(float(p[14]), 4),
            "mass": round(float(p[10]), 4),
            "theta": round(float(p[4]), 4),
            "phi": round(float(p[5]), 4),
            "vx": round(float(p[15]), 4),
            "vy": round(float(p[16]), 4),
            "vz": round(float(p[17]), 4),
        }

    return pd.DataFrame([_row(i, x_gen[i]) for i in range(x_gen.shape[0])])
504
+
505
+
506
def _extract_pandora_particles(event):
    """Build a DataFrame of Pandora PFO particles from the event's ``X_pandora``.

    ``X_pandora`` columns (per PFO):
        0: pid (PDG ID)
        1–3: px, py, pz (momentum components at reference point)
        4–6: ref_x, ref_y, ref_z (reference point)
        7: energy
        8: momentum magnitude

    Returns
    -------
    tuple
        ``(pandora_particles_df, pfo_hit_links, pfo_track_links)`` where the
        link arrays map each calorimeter hit / track to a 0-based PFO index
        (-1 = unassigned). The DataFrame is empty when ``X_pandora`` is
        missing or malformed.
    """
    x_pandora = np.asarray(event.get("X_pandora", []))
    pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
    pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64)

    if x_pandora.ndim != 2 or x_pandora.shape[0] == 0 or x_pandora.shape[1] < 9:
        return pd.DataFrame(), pfo_calohit, pfo_track

    records = []
    for idx in range(x_pandora.shape[0]):
        pfo = x_pandora[idx]
        pdg = int(pfo[0])
        records.append({
            "pfo_idx": idx,
            "pid": pdg,
            "pdg_name": _ABS_PDG_NAME.get(abs(pdg), str(pdg)),
            "energy": round(float(pfo[7]), 4),
            "momentum": round(float(pfo[8]), 4),
            "px": round(float(pfo[1]), 4),
            "py": round(float(pfo[2]), 4),
            "pz": round(float(pfo[3]), 4),
            "ref_x": round(float(pfo[4]), 2),
            "ref_y": round(float(pfo[5]), 2),
            "ref_z": round(float(pfo[6]), 2),
        })
    return pd.DataFrame(records), pfo_calohit, pfo_track
544
+
545
+
546
def _run_energy_correction(model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev):
    """Run the energy correction & PID branch and enrich *particles_df*.

    Matches predicted clusters to truth particles, builds one sub-graph per
    cluster (matched + fakes), computes post-clustering features, and runs
    the charged/neutral energy-correction and PID heads.

    Parameters
    ----------
    model : ExampleWrapper
        Model exposing ``energy_correction.model_charged`` /
        ``model_neutral`` and their PID class lists.
    g : dgl.DGLGraph
        The full event graph (node data: ``h``, ``e_hits``,
        ``particle_number``, ``chi_squared_tracks``, ...).
    x_cluster_coord, beta
        Learned clustering coordinates and condensation strengths.
    labels : torch.Tensor
        Per-hit cluster labels (0 = noise).
    y_data : Particles_GT
        Ground-truth particle record for this event.
    particles_df : pandas.DataFrame
        Basic per-cluster summary; returned unchanged if there is nothing
        to correct.
    dev : torch.device

    Returns
    -------
    pandas.DataFrame
        One row per cluster with corrected energy, PID and direction.
    """
    from src.layers.shower_matching import match_showers, obtain_intersection_matrix, obtain_union_matrix
    from torch_scatter import scatter_add, scatter_mean
    from src.utils.post_clustering_features import (
        get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
    )

    x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)

    # Re-create per-cluster sub-graphs expected by the correction pipeline
    particle_ids = torch.unique(g.ndata["particle_number"])
    shower_p_unique = torch.unique(labels)
    model_output_dummy = x  # used only for device by match_showers

    shower_p_unique_m, row_ind, col_ind, i_m_w, _ = match_showers(
        labels, {"graph": g, "part_true": y_data},
        particle_ids, model_output_dummy, 0, 0, None,
    )
    row_ind = torch.Tensor(row_ind).to(dev).long()
    col_ind = torch.Tensor(col_ind).to(dev).long()
    # particle id 0 is the noise label; shift truth row indices when present.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind
    # +1 converts matched column indices back to cluster labels (0 = noise).
    index_matches = (col_ind + 1).to(dev).long()

    # Build per-cluster sub-graphs (matched + fakes)
    graphs_matched = []
    true_energies = []
    reco_energies = []
    pids_matched = []
    coords_matched = []
    e_true_daughters = []

    for j, sh_label in enumerate(index_matches):
        # Skip labels matched more than once (ambiguous assignments).
        if torch.sum(sh_label == index_matches) == 1:
            mask = labels == sh_label
            sg = dgl.graph(([], []))
            sg.add_nodes(int(mask.sum()))
            sg = sg.to(dev)
            sg.ndata["h"] = g.ndata["h"][mask]
            if "pos_pxpypz" in g.ndata:
                sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
            if "pos_pxpypz_at_vertex" in g.ndata:
                sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
            sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
            energy_t = y_data.E.to(dev)
            true_e = energy_t[row_ind_[j]]
            pids_matched.append(y_data.pid[row_ind_[j]].item())
            coords_matched.append(y_data.coord[row_ind_[j]].detach().cpu().numpy())
            e_true_daughters.append(y_data.m[row_ind_[j]].to(dev))
            reco_e = torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask])
            graphs_matched.append(sg)
            true_energies.append(true_e.view(-1))
            reco_energies.append(reco_e.view(-1))

    # Add fakes
    # Clusters that were not matched to any truth particle (excluding the
    # noise slot 0) are treated as fake showers.
    pred_showers = shower_p_unique_m.clone()
    pred_showers[index_matches] = -1
    pred_showers[0] = -1
    fakes_mask = pred_showers != -1
    fakes_idx = torch.where(fakes_mask)[0]

    graphs_fakes = []
    reco_fakes = []
    for fi in fakes_idx:
        mask = labels == fi
        sg = dgl.graph(([], []))
        sg.add_nodes(int(mask.sum()))
        sg = sg.to(dev)
        sg.ndata["h"] = g.ndata["h"][mask]
        if "pos_pxpypz" in g.ndata:
            sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
        if "pos_pxpypz_at_vertex" in g.ndata:
            sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
        sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
        graphs_fakes.append(sg)
        reco_fakes.append(torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask]).view(-1))

    if not graphs_matched and not graphs_fakes:
        return particles_df

    all_graphs = dgl.batch(graphs_matched + graphs_fakes)
    sum_e = torch.cat(reco_energies + reco_fakes, dim=0)

    # Compute high-level features
    batch_num_nodes = all_graphs.batch_num_nodes()
    batch_idx = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(dev)

    # Normalise positions; 3300 presumably matches the detector's outer
    # radius in mm used during training — confirm against the training code.
    all_graphs.ndata["h"][:, 0:3] = all_graphs.ndata["h"][:, 0:3] / 3300
    graphs_sum_features = scatter_add(all_graphs.ndata["h"], batch_idx, dim=0)
    graphs_sum_features = graphs_sum_features[batch_idx]
    betas = torch.sigmoid(all_graphs.ndata["h"][:, -1])
    all_graphs.ndata["h"] = torch.cat(
        (all_graphs.ndata["h"], graphs_sum_features), dim=1
    )

    high_level = get_post_clustering_features(all_graphs, sum_e)
    extra_features = get_extra_features(all_graphs, betas)

    n_clusters = high_level.shape[0]
    pred_energy = torch.ones(n_clusters, device=dev)
    pred_pos = torch.ones(n_clusters, 3, device=dev)
    pred_pid = torch.ones(n_clusters, device=dev).long()

    node_features_avg = scatter_mean(all_graphs.ndata["h"], batch_idx, dim=0)[:, 0:3]
    eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
    phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
    high_level = torch.cat(
        (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
    )

    # Column 7 of the post-clustering features is taken as the track count;
    # NOTE(review): this index is defined by get_post_clustering_features —
    # verify it stays in sync with that function.
    num_tracks = high_level[:, 7]
    charged_idx = torch.where(num_tracks >= 1)[0]
    neutral_idx = torch.where(num_tracks < 1)[0]

    def zero_nans(t):
        # Replace NaNs with 0 (x != x is only true for NaN).
        out = t.clone()
        out[out != out] = 0
        return out

    feats_charged = zero_nans(high_level[charged_idx])
    feats_neutral = zero_nans(high_level[neutral_idx])

    # Run charged prediction
    charged_energies = model.energy_correction.model_charged.charged_prediction(
        all_graphs, charged_idx, feats_charged,
    )
    # Run neutral prediction
    neutral_energies, neutral_pxyz_avg = model.energy_correction.model_neutral.neutral_prediction(
        all_graphs, neutral_idx, feats_neutral,
    )

    pids_charged = model.energy_correction.pids_charged
    pids_neutral = model.energy_correction.pids_neutral

    # The heads return an extra PID-logits tensor only when a PID class
    # list is configured.
    if len(pids_charged):
        ch_e, ch_pos, ch_pid_logits, ch_ref = charged_energies
    else:
        ch_e, ch_pos, _ = charged_energies
        ch_pid_logits = None

    if len(pids_neutral):
        ne_e, ne_pos, ne_pid_logits, ne_ref = neutral_energies
    else:
        ne_e, ne_pos, _ = neutral_energies
        ne_pid_logits = None

    pred_energy[charged_idx.flatten()] = ch_e if len(charged_idx) else pred_energy[charged_idx.flatten()]
    pred_energy[neutral_idx.flatten()] = ne_e if len(neutral_idx) else pred_energy[neutral_idx.flatten()]

    if ch_pid_logits is not None and len(charged_idx):
        ch_labels = np.array(pids_charged)[np.argmax(ch_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[charged_idx.flatten()] = torch.tensor(ch_labels).long().to(dev)
    if ne_pid_logits is not None and len(neutral_idx):
        ne_labels = np.array(pids_neutral)[np.argmax(ne_pid_logits.cpu().detach().numpy(), axis=1)]
        pred_pid[neutral_idx.flatten()] = torch.tensor(ne_labels).long().to(dev)

    # Clamp unphysical negative energies to zero.
    pred_energy[pred_energy < 0] = 0.0

    # Direction
    if len(charged_idx):
        pred_pos[charged_idx.flatten()] = ch_pos.float().to(dev)
    if len(neutral_idx):
        pred_pos[neutral_idx.flatten()] = ne_pos.float().to(dev)

    # Build enriched output DataFrame
    # Matched clusters come first in the batch, fakes after them.
    n_matched = len(graphs_matched)
    rows = []
    for k in range(n_clusters):
        is_fake = k >= n_matched
        pid_cls = int(pred_pid[k].item())
        rows.append({
            "cluster_id": k + 1,
            "corrected_energy": round(pred_energy[k].item(), 4),
            "raw_energy": round(sum_e[k].item(), 4),
            "pid_class": pid_cls,
            "pid_label": _PID_LABELS.get(pid_cls, str(pid_cls)),
            "px": round(pred_pos[k, 0].item(), 4),
            "py": round(pred_pos[k, 1].item(), 4),
            "pz": round(pred_pos[k, 2].item(), 4),
            "is_charged": bool(k in charged_idx),
            "is_fake": is_fake,
        })

    return pd.DataFrame(rows)
src/layers/clustering.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Clustering algorithms for particle-flow reconstruction.
2
+
3
+ Adapted from densitypeakclustering (https://github.com/lanbing510/DensityPeakCluster).
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from torch_scatter import scatter_add
8
+ import densitypeakclustering as dc
9
+
10
+
11
def local_density_energy(D, d_c, energies, normalize=False):
    """Energy-weighted local density for every point.

    For each point, hits within the cutoff distance ``d_c`` contribute their
    energy, down-weighted by a Gaussian kernel of width ``d_c`` in distance.

    Args:
        D: (n, n) pairwise distance matrix.
        d_c: cutoff distance for the density estimate.
        energies: (n,) per-hit energies.
        normalize: if True, scale densities so the maximum is 1.

    Returns:
        (n,) array of local densities.
    """
    n_points = D.shape[0]
    within_cutoff = D < d_c
    rho = np.zeros((n_points,))
    for idx in range(n_points):
        neighbours = within_cutoff[idx, :]
        weights = np.exp(-((D[idx, neighbours] / d_c) ** 2))
        rho[idx] = np.sum(energies[neighbours] * weights)
    return rho / np.max(rho) if normalize else rho
19
+
20
+
21
def DPC_custom_CLD(X, g, device):
    """Density-peak clustering of hits in the learned cluster space.

    Densities are energy-weighted (via ``local_density_energy``); cluster
    centers are points with both high density and large separation from any
    denser point. Only "core" members closer than 0.5 to their center keep
    the cluster id; everything else is labelled noise.

    Args:
        X: per-hit coordinates in cluster space (torch tensor).
        g: graph providing ``ndata["e_hits"]`` energies.
        device: device for the returned label tensor.

    Returns:
        Long tensor of per-hit labels, 0 = noise, clusters start at 1.
    """
    cutoff_distance = 0.1   # d_c for the density kernel
    min_density = 0.05      # rho threshold for a center
    min_separation = 0.4    # delta threshold for a center
    dist = dc.distance_matrix(X.detach().cpu())
    hit_energies = g.ndata["e_hits"].view(-1).cpu().numpy()
    rho = local_density_energy(dist, cutoff_distance, hit_energies)
    delta, nearest = dc.distance_to_larger_density(dist, rho)
    centers = dc.cluster_centers(rho, delta, rho_min=min_density, delta_min=min_separation)
    ids = dc.assign_cluster_id(rho, nearest, centers)
    core_ids = np.full(len(X), -1)
    dist[np.isnan(dist)] = 0
    for center_no, center in enumerate(centers):
        core_members = np.where((ids == center_no) & (dist[:, center] < 0.5))[0]
        core_ids[core_members] = center_no
    # shift so that -1 (unassigned) becomes the 0 noise label
    return (torch.Tensor(core_ids) + 1).long().to(device)
37
+
38
+
39
def remove_bad_tracks_from_cluster(g, labels_hdb):
    """Detach energy-incompatible tracks from predicted clusters.

    For every non-noise cluster containing at least one track hit
    (``hit_type == 1``), the summed calorimeter energy of the cluster is
    compared to each track momentum.  A track is removed (re-labelled as
    noise, label 0) when its relative energy/momentum difference exceeds a
    4-sigma window with sigma = 0.5 / sqrt(p), unless the cluster also
    contains muon hits (``hit_type == 4``).

    Args:
        g: graph with node data ``"hit_type"``, ``"e_hits"`` and ``"p_hits"``.
        labels_hdb: per-node cluster labels (0 = noise).

    Returns:
        Tuple ``(corrected_labels, changed_mask)``: the corrected labels and a
        per-node indicator set to 1 on every node of a cluster from which at
        least one track was removed.
    """
    mask_hit_type_t2 = g.ndata["hit_type"] == 1  # tracks
    mask_hit_type_t4 = g.ndata["hit_type"] == 4  # muon hits
    labels_hdb_corrected_tracks = labels_hdb.clone()
    labels_changed_tracks = 0.0 * (labels_hdb.clone())
    # Guard the empty-graph case: torch.max on an empty tensor raises.
    if labels_hdb.numel() == 0:
        return labels_hdb_corrected_tracks, labels_changed_tracks
    for i in range(0, int(torch.max(labels_hdb)) + 1):
        mask_labels_i = labels_hdb == i
        # label 0 is the noise cluster: nothing to correct there
        if torch.sum(mask_hit_type_t2[mask_labels_i]) > 0 and i > 0:
            e_cluster = torch.sum(g.ndata["e_hits"][mask_labels_i])
            p_track = g.ndata["p_hits"][mask_labels_i * mask_hit_type_t2]
            number_of_hits_muon = torch.sum(mask_labels_i * mask_hit_type_t4)
            diffs = torch.abs(e_cluster - p_track) / p_track
            diffs = diffs.view(-1)
            # 4 * sigma, with sigma = 0.5 / sqrt(p)
            sigma_4 = 4 * 0.5 / torch.sqrt(p_track).view(-1)
            bad_diffs = diffs > sigma_4
            # keep tracks when the cluster contains muon hits
            bad_tracks = bad_diffs * (number_of_hits_muon < 1)
            cluster_t2_nodes = torch.nonzero(mask_labels_i & mask_hit_type_t2).view(-1)
            bad_tracks_nodes = cluster_t2_nodes[bad_tracks]
            labels_hdb_corrected_tracks[bad_tracks_nodes] = 0
            # Fix: flag the cluster when ANY track was removed.  The previous
            # check summed the node indices, which silently missed the case
            # where the only removed track is node 0.
            if len(bad_tracks_nodes) > 0:
                labels_changed_tracks[mask_labels_i] = 1
    return labels_hdb_corrected_tracks, labels_changed_tracks
62
+
63
+
64
def remove_labels_of_double_showers(labels, g):
    """Resolve clusters that contain exactly two track hits.

    When a predicted shower holds two tracks, one of them is assumed to be a
    mis-assignment; this picks one of the two tracks (based on distance to
    the cluster core, energy compatibility, and track chi-squared) and
    re-labels it as noise (label 0). Mutates and returns ``labels``.

    Reads node data: "hit_type", "e_hits", "h" (last column used as the
    track energy), "chi_squared_tracks", "radial_distance", "pos_hits_xyz".
    """
    # per-label number of track hits and summed hit energy
    is_track_per_shower = scatter_add(1 * (g.ndata["hit_type"] == 1), labels).int()
    e_hits_sum = scatter_add(g.ndata["e_hits"].view(-1), labels.view(-1).long()).int()
    mask_tracks = g.ndata["hit_type"] == 1
    for i, label_i in enumerate(torch.unique(labels)):
        # only act on non-noise clusters with exactly two tracks
        if is_track_per_shower[label_i] == 2:
            if label_i > 0:
                sum_pred_2 = e_hits_sum[label_i]
                mask_labels_i = labels == label_i
                mask_label_i_and_is_track = mask_labels_i * mask_tracks
                # last feature column of "h" is used as the track energy here
                tracks_E = g.ndata['h'][:, -1][mask_label_i_and_is_track]
                chi_tracks = g.ndata['chi_squared_tracks'][mask_label_i_and_is_track]
                # NOTE(review): despite the name, this is the track whose energy
                # differs MOST from the cluster energy (argmax) — confirm intent.
                ind_min_E = torch.argmax(torch.abs(tracks_E - sum_pred_2))
                ind_min_chi = torch.argmax(chi_tracks)
                # masks restricted to this cluster's nodes
                mask_hit_type_t1 = g.ndata["hit_type"][mask_labels_i] == 2
                mask_hit_type_t2 = g.ndata["hit_type"][mask_labels_i] == 1
                mask_all = mask_hit_type_t1
                # mean position of the ~10 innermost calorimeter hits of the cluster
                index_sorted = torch.argsort(g.ndata["radial_distance"][mask_labels_i][mask_hit_type_t1])
                mask_sorted_ind = index_sorted < 10
                mean_pos_cluster = torch.mean(
                    g.ndata["pos_hits_xyz"][mask_labels_i][mask_all][mask_sorted_ind], dim=0
                )
                pos_track = g.ndata["pos_hits_xyz"][mask_labels_i][mask_hit_type_t2]
                # distance of each track to the cluster core, in metres (pos in mm)
                distance_track_cluster = torch.norm(pos_track - mean_pos_cluster, dim=1) / 1000
                ind_max_dtc = torch.argmax(distance_track_cluster)
                # Decision cascade: prefer geometric evidence, then agreement of
                # the energy and chi-squared criteria, then chi-squared spread.
                if torch.min(distance_track_cluster) < 0.4:
                    ind_min = ind_max_dtc
                elif ind_min_E == ind_min_chi:
                    ind_min = ind_min_E
                elif torch.max(chi_tracks - torch.min(chi_tracks)) < 2:
                    ind_min = ind_min_E
                else:
                    ind_min = ind_min_chi
                ind_change = torch.argwhere(mask_label_i_and_is_track)[ind_min]
                labels[ind_change] = 0
    return labels
src/layers/inference_oc.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file includes code adapted from:
3
+
4
+ densitypeakclustering
5
+ https://github.com/lanbing510/DensityPeakCluster
6
+
7
+ The original implementation has been modified and integrated into this project.
8
+ Please refer to the original repository for authorship, documentation,
9
+ and license information.
10
+ """
11
+ import dgl
12
+ import torch
13
+ import pandas as pd
14
+ import numpy as np
15
+ import wandb
16
+
17
+ from src.layers.clustering import (
18
+ local_density_energy,
19
+ DPC_custom_CLD,
20
+ remove_bad_tracks_from_cluster,
21
+ remove_labels_of_double_showers,
22
+ )
23
+ from src.layers.shower_matching import (
24
+ CachedIndexList,
25
+ get_labels_pandora,
26
+ obtain_intersection_matrix,
27
+ obtain_union_matrix,
28
+ obtain_intersection_values,
29
+ match_showers,
30
+ )
31
+ from src.layers.shower_dataframe import (
32
+ get_correction_per_shower,
33
+ distance_to_true_cluster_of_track,
34
+ distance_to_cluster_track,
35
+ generate_showers_data_frame,
36
+ )
37
+
38
+ # Re-export everything so existing callers (utils_training, Gatr_pf_e_noise, …)
39
+ # that do `from src.layers.inference_oc import X` continue to work unchanged.
40
+ __all__ = [
41
+ "local_density_energy",
42
+ "DPC_custom_CLD",
43
+ "remove_bad_tracks_from_cluster",
44
+ "remove_labels_of_double_showers",
45
+ "CachedIndexList",
46
+ "get_labels_pandora",
47
+ "obtain_intersection_matrix",
48
+ "obtain_union_matrix",
49
+ "obtain_intersection_values",
50
+ "match_showers",
51
+ "get_correction_per_shower",
52
+ "distance_to_true_cluster_of_track",
53
+ "distance_to_cluster_track",
54
+ "generate_showers_data_frame",
55
+ "log_efficiency",
56
+ "store_at_batch_end",
57
+ "create_and_store_graph_output",
58
+ ]
59
+
60
+
61
def log_efficiency(df, pandora=False, clustering=False):
    """Log the shower reconstruction efficiency of a batch to wandb.

    Efficiency = fraction of truth showers (rows with a valid
    ``reco_showers_E``) that received a predicted shower
    (non-NaN ``pred_showers_E``). The wandb key depends on which
    reconstruction produced ``df``.
    """
    valid = ~np.isnan(df["reco_showers_E"])
    preds = df["pred_showers_E"][valid].values
    eff = np.sum(~np.isnan(preds)) / len(preds)
    if pandora:
        key = "efficiency validation pandora"
    elif clustering:
        key = "efficiency validation clustering"
    else:
        key = "efficiency validation"
    wandb.log({key: eff})
72
+
73
+
74
+ def _make_save_path(path_save, local_rank, step, epoch, suffix=""):
75
+ return path_save + str(local_rank) + "_" + str(step) + "_" + str(epoch) + suffix + ".pt"
76
+
77
+
78
def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
    pandora_available=False,
):
    """Persist per-batch shower dataframes and log efficiencies to wandb.

    Dataframes are pickled only when both ``store`` and ``predict`` are set.
    The model efficiency is always logged; the Pandora efficiency only when
    Pandora output is available during prediction.
    """
    should_write = store and predict
    if should_write:
        df_batch1.to_pickle(_make_save_path(path_save, local_rank, step, epoch))
    has_pandora = predict and pandora_available
    if has_pandora and should_write:
        df_batch_pandora.to_pickle(
            _make_save_path(path_save, local_rank, step, epoch, "_pandora")
        )
    log_efficiency(df_batch1)
    if has_pandora:
        log_efficiency(df_batch_pandora, pandora=True)
99
+
100
+
101
def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    e_corr=None,
    ec_x=None,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pred_pid=None,
    number_of_fakes=None,
    extra_features=None,
    fakes_labels=None,
    pandora_available=False,
    truth_tracks=False,
):
    """Cluster every event of a batched graph and build shower dataframes.

    For each event: cluster the model's coordinate output (or reuse ground-truth
    labels when ``use_gt_clusters``), optionally clean bad tracks, match the
    resulting showers to truth particles, and accumulate one per-event dataframe.
    When ``predict`` and ``pandora_available``, the same matching is repeated for
    the Pandora labels. Results may be pickled via ``store_at_batch_end``.

    Returns ``(df_pandora, df_model, total_number_events)`` when ``predict``,
    otherwise just the model dataframe.
    """
    number_of_showers_total = 0  # (unused here)
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    # First three output columns are cluster-space coords, fourth the beta.
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if e_corr is None:
        # no external energy correction: take it from the model head
        batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)
    df_list1 = []
    df_list_pandora = []
    for i in range(0, len(graphs)):
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        # restrict the truth container to this event's particles
        y1 = y.copy()
        y1.mask(mask)
        dic["part_true"] = y1
        X = dic["graph"].ndata["coords"]
        labels_clusters_removed_tracks = torch.zeros(
            dic["graph"].num_nodes(), device=model_output.device
        )
        if use_gt_clusters:
            labels_hdb = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = DPC_custom_CLD(X, dic["graph"], model_output.device)
        if not truth_tracks:
            labels_hdb, labels_clusters_removed_tracks = remove_bad_tracks_from_cluster(
                dic["graph"], labels_hdb
            )
        if predict and pandora_available:
            labels_pandora = get_labels_pandora(dic, model_output.device)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])

        # match predicted showers to truth showers (Hungarian-style matching
        # happens inside match_showers)
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, iou_m = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            hdbscan=True,
        )
        if predict and pandora_available:
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                iou_m_pandora,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
            )

        # more than just the noise cluster was found
        if len(shower_p_unique_hdb) > 1:
            df_event1, number_of_showers_total1, number_of_fake_showers_total1 = generate_showers_data_frame(
                labels_hdb,
                dic,
                shower_p_unique_hdb,
                particle_ids,
                row_ind_hdb,
                col_ind_hdb,
                i_m_w_hdb,
                e_corr=e_corr,
                number_of_showers_total=number_of_showers_total1,
                step=step,
                number_in_batch=total_number_events,
                ec_x=ec_x,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                number_of_fakes=number_of_fakes,
                number_of_fake_showers_total=number_of_fake_showers_total1,
                extra_features=extra_features,
                labels_clusters_removed_tracks=labels_clusters_removed_tracks,
            )
            if len(df_event1) > 1:
                df_list1.append(df_event1)
        if predict and pandora_available:
            df_event_pandora = generate_showers_data_frame(
                labels_pandora,
                dic,
                shower_p_unique_pandora,
                particle_ids,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                pandora=True,
                step=step,
                number_in_batch=total_number_events,
            )
            # a tuple return signals the degenerate case — skip it
            if df_event_pandora is not None and type(df_event_pandora) is not tuple:
                df_list_pandora.append(df_event_pandora)
            else:
                print("Not appending to df_list_pandora")
        total_number_events = total_number_events + 1

    df_batch1 = pd.concat(df_list1)
    if predict and pandora_available:
        df_batch_pandora = pd.concat(df_list_pandora)
    else:
        df_batch = []
        df_batch_pandora = []
    if store:
        store_at_batch_end(
            path_save,
            df_batch1,
            df_batch_pandora,
            local_rank,
            step,
            epoch,
            predict=predict,
            store=store_epoch,
            pandora_available=pandora_available,
        )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    else:
        return df_batch1
src/layers/object_cond.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The loss implementation in this file is adapted from the HGCalML repository:
3
+
4
+ Repository: https://github.com/jkiesele/HGCalML
5
+ File: modules/lossLayers.py
6
+
7
+ Original author: Jan Kieseler
8
+ License: See the original repository for license details.
9
+
10
+ The implementation has been modified and integrated into this project.
11
+ """
12
+
13
+ from typing import Tuple, Union
14
+ import numpy as np
15
+ import torch
16
+ from torch_scatter import scatter_max, scatter_add, scatter_mean
17
+ import dgl
18
+
19
def safe_index(arr, index):
    """Return the 1-based position of ``index`` in ``arr``, or 0 when absent."""
    try:
        return arr.index(index) + 1
    except ValueError:
        return 0
25
+
26
+
27
def assert_no_nans(x):
    """Raise ``AssertionError`` (printing the tensor first) if ``x`` contains NaN."""
    has_nan = torch.isnan(x).any()
    if has_nan:
        print(x)
    assert not has_nan
34
+
35
+
36
def calc_LV_Lbeta(
    original_coords,
    g,
    y,
    distance_threshold,
    energy_correction,
    beta: torch.Tensor,
    cluster_space_coords: torch.Tensor,  # Predicted by model
    cluster_index_per_event: torch.Tensor,  # Truth hit->cluster index
    batch: torch.Tensor,
    predicted_pid=None,  # predicted PID embeddings - will be aggregated by summing up the clusters and applying the post_pid_pool_module MLP afterwards
    # From here on just parameters
    qmin: float = 0.1,
    s_B: float = 1.0,
    noise_cluster_index: int = 0,  # cluster_index entries with this value are noise/noise
    frac_combinations=0,  # fraction of the all possible pairs to be used for the clustering loss
    use_average_cc_pos=0.0,
    loss_type="hgcalimplementation",
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]:
    """
    Calculates the L_V and L_beta object condensation losses.
    Concepts:
    - A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)),
      and to exactly one event (batch is (n_hits,))
    - A cluster index of `noise_cluster_index` means the cluster is a noise cluster.
      There is typically one noise cluster per event. Any hit in a noise cluster
      is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a
      better term.
    - An 'object' is a cluster that is *not* a noise cluster.
    beta_stabilizing: Choices are ['paper', 'clip', 'soft_q_scaling']:
        paper: beta is sigmoid(model_output), q = beta.arctanh()**2 + qmin
        clip: beta is clipped to 1-1e-4, q = beta.arctanh()**2 + qmin
        soft_q_scaling: beta is sigmoid(model_output), q = (clip(beta)/1.002).arctanh()**2 + qmin
    huberize_norm_for_V_attractive: Huberizes the norms when used in the attractive potential
    beta_term_option: Choices are ['paper', 'short-range-potential']:
        Choosing 'short-range-potential' introduces a short range potential around high
        beta points, acting like V_attractive.
    Note this function has modifications w.r.t. the implementation in 2002.03605:
    - The norms for V_repulsive are now Gaussian (instead of linear hinge)

    NOTE(review): `g`, `y`, `distance_threshold`, `energy_correction` and
    `predicted_pid` are accepted but never referenced in this code path.
    """
    # remove dummy rows added for dataloader #TODO think of better way to do this
    device = beta.device
    if torch.isnan(beta).any():
        print("There are nans in beta! L198", len(beta[torch.isnan(beta)]))

    # sanitize betas before the arctanh-based charge computation below
    beta = torch.nan_to_num(beta, nan=0.0)
    assert_no_nans(beta)
    # ________________________________
    # Calculate a bunch of needed counts and indices locally

    # cluster_index: unique index over events
    # E.g. cluster_index_per_event=[ 0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
    # -> cluster_index=[ 0, 0, 1, 2, 3, 3, 4 ]
    cluster_index, n_clusters_per_event = batch_cluster_indices(
        cluster_index_per_event, batch
    )
    n_clusters = n_clusters_per_event.sum()
    n_hits, cluster_space_dim = cluster_space_coords.size()
    batch_size = batch.max() + 1  # (unused)
    n_hits_per_event = scatter_count(batch)  # (unused)

    # Index of cluster -> event (n_clusters,)
    batch_cluster = scatter_counts_to_indices(n_clusters_per_event)

    # Per-hit boolean, indicating whether hit is sig or noise
    is_noise = cluster_index_per_event == noise_cluster_index
    is_sig = ~is_noise
    n_hits_sig = is_sig.sum()
    n_sig_hits_per_event = scatter_count(batch[is_sig])  # (unused)

    # Per-cluster boolean, indicating whether cluster is an object or noise
    is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
    is_noise_cluster = ~is_object  # (unused)

    # object indexing below assumes noise is cluster 0 (objects start at 1)
    if noise_cluster_index != 0:
        raise NotImplementedError
    object_index_per_event = cluster_index_per_event[is_sig] - 1
    object_index, n_objects_per_event = batch_cluster_indices(
        object_index_per_event, batch[is_sig]
    )
    n_hits_per_object = scatter_count(object_index)
    # print("n_hits_per_object", n_hits_per_object)
    batch_object = batch_cluster[is_object]  # (unused)
    n_objects = is_object.sum()

    assert object_index.size() == (n_hits_sig,)
    assert is_object.size() == (n_clusters,)
    assert torch.all(n_hits_per_object > 0)
    assert object_index.max() + 1 == n_objects

    # ________________________________
    # L_V term

    # Calculate q: the per-hit "charge", growing with beta
    q = (beta.clip(0.0, 1 - 1e-4).arctanh() / 1.01) ** 2 + qmin
    assert_no_nans(q)
    assert q.device == device
    assert q.size() == (n_hits,)

    # Calculate q_alpha, the max q per object, and the indices of said maxima
    # assert hit_energies.shape == q.shape
    # q_alpha, index_alpha = scatter_max(hit_energies[is_sig], object_index)
    q_alpha, index_alpha = scatter_max(q[is_sig], object_index)
    assert q_alpha.size() == (n_objects,)

    # Get the cluster space coordinates and betas for these maxima hits too
    x_alpha = cluster_space_coords[is_sig][index_alpha]
    x_alpha_original = original_coords[is_sig][index_alpha]
    if use_average_cc_pos > 0:
        # Blend the condensation-point position with the q-weighted mean
        # position of the object's hits.
        x_alpha_sum = scatter_add(
            q[is_sig].view(-1, 1).repeat(1, 3) * cluster_space_coords[is_sig],
            object_index,
            dim=0,
        )  # * beta[is_sig].view(-1, 1).repeat(1, 3)
        qbeta_alpha_sum = scatter_add(q[is_sig], object_index) + 1e-9  # * beta[is_sig]
        div_fac = 1 / qbeta_alpha_sum
        div_fac = torch.nan_to_num(div_fac, nan=0)
        x_alpha_mean = torch.mul(x_alpha_sum, div_fac.view(-1, 1).repeat(1, 3))
        x_alpha = use_average_cc_pos * x_alpha_mean + (1 - use_average_cc_pos) * x_alpha

    beta_alpha = beta[is_sig][index_alpha]
    assert x_alpha.size() == (n_objects, cluster_space_dim)
    assert beta_alpha.size() == (n_objects,)

    # Connectivity matrix from hit (row) -> cluster (column)
    # Index to matrix, e.g.:
    # [1, 3, 1, 0] --> [
    #     [0, 1, 0, 0],
    #     [0, 0, 0, 1],
    #     [0, 1, 0, 0],
    #     [1, 0, 0, 0]
    #     ]
    M = torch.nn.functional.one_hot(cluster_index).long()

    # Anti-connectivity matrix; be sure not to connect hits to clusters in different events!
    M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M

    # Throw away noise cluster columns; we never need them
    M = M[:, is_object]
    M_inv = M_inv[:, is_object]
    assert M.size() == (n_hits, n_objects)
    assert M_inv.size() == (n_hits, n_objects)

    # Calculate all norms
    # Warning: Should not be used without a mask!
    # Contains norms between hits and objects from different events
    # (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim)
    #   gives (n_hits, n_objects, cluster_space_dim)
    norms = (cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)).norm(dim=-1)
    assert norms.size() == (n_hits, n_objects)
    L_clusters = torch.tensor(0.0).to(device)
    if frac_combinations != 0:
        # optional sampled pairwise hinge term
        L_clusters = L_clusters_calc(
            batch, cluster_space_coords, cluster_index, frac_combinations, q
        )

    # -------
    # Attractive potential term
    # First get all the relevant norms: We only want norms of signal hits
    # w.r.t. the object they belong to, i.e. no noise hits and no noise clusters.
    # First select all norms of all signal hits w.r.t. all objects, mask out later

    N_k = torch.sum(M, dim=0)  # number of hits per object
    # overwrite `norms` with the SQUARED distances from here on
    norms = torch.sum(
        torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)),
        dim=-1,
    )  # take the norm squared
    norms_att = norms[is_sig]
    # att func as in line 159 of object condensation
    norms_att = torch.log(
        torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att / 2 + 1
    )

    assert norms_att.size() == (n_hits_sig, n_objects)

    # Now apply the mask to keep only norms of signal hits w.r.t. to the object
    # they belong to
    norms_att *= M[is_sig]

    # Sum over hits, then sum per event, then divide by n_hits_per_event, then sum over events

    V_attractive = (q[is_sig]).unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
    V_attractive = V_attractive.sum(dim=0)  # K objects
    V_attractive = V_attractive.view(-1) / (N_k.view(-1) + 1e-3)
    L_V_attractive = torch.mean(V_attractive)

    # repulsive hinge: only hits within unit distance of a foreign object repel
    norms_rep = torch.relu(1. - torch.sqrt(norms + 1e-6))* M_inv

    # (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects)
    V_repulsive = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep

    # No need to apply a V = max(0, V); by construction V>=0
    assert V_repulsive.size() == (n_hits, n_objects)

    # Sum over hits, then sum per event, then divide by n_hits_per_event, then sum up events
    nope = n_objects_per_event - 1  # (unused beyond this normalisation guard)
    nope[nope == 0] = 1

    L_V_repulsive = V_repulsive.sum(dim=0)
    number_of_repulsive_terms_per_object = torch.sum(M_inv, dim=0)
    L_V_repulsive = L_V_repulsive.view(
        -1
    ) / number_of_repulsive_terms_per_object.view(-1)
    L_V_repulsive = torch.mean(L_V_repulsive)
    L_V_repulsive2 = L_V_repulsive

    L_V = (
        L_V_attractive
        + L_V_repulsive

    )

    # L_beta noise term: push betas of noise hits towards 0 (per-event mean)
    n_noise_hits_per_event = scatter_count(batch[is_noise])
    n_noise_hits_per_event[n_noise_hits_per_event == 0] = 1
    L_beta_noise = (
        s_B
        * (
            (scatter_add(beta[is_noise], batch[is_noise])) / n_noise_hits_per_event
        ).sum()
    )

    # L_beta signal term

    beta_per_object_c = scatter_add(beta[is_sig], object_index)
    beta_alpha = beta[is_sig][index_alpha]
    # hit_type_mask = (g.ndata["hit_type"]==1)*(g.ndata["particle_number"]>0)
    # beta_alpha_track = beta[is_sig*hit_type_mask]
    L_beta_sig = torch.mean(
        1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1)
    )

    # empirical down-weighting of the noise term
    L_beta_noise = L_beta_noise / 4

    L_beta = L_beta_noise + L_beta_sig

    # distance between condensation points in input vs cluster space
    L_alpha_coordinates = torch.mean(torch.norm(x_alpha_original - x_alpha, p=2, dim=1))

    L_exp = L_beta
    # NOTE(review): for any other loss_type the function implicitly returns None.
    if (loss_type == "hgcalimplementation") or (loss_type == "vrepweighted") or (loss_type == "baseline"):
        return (
            L_V,
            L_beta,
            L_beta_sig,
            L_beta_noise,
            0,
            0,
            0,
            None,
            None,
            0,
            L_clusters,
            0,
            L_V_attractive,
            L_V_repulsive,
            L_alpha_coordinates,
            L_exp,
            norms_rep,
            norms_att,
            L_V_repulsive2,
            0
        )
308
+
309
+
310
def object_condensation_loss2(
    batch,
    pred,
    pred_2,
    y,
    q_min=0.1,
    use_average_cc_pos=0.0,
    output_dim=4,
    clust_space_norm="none",
):
    """Compute the object-condensation loss for one batched graph.

    :param batch: batched DGL graph (uses ndata "h" and "particle_number")
    :param pred: per-hit model output; first 3 columns are cluster-space
        coordinates, column 3 the raw beta logit
    :param pred_2: energy-correction output, forwarded to calc_LV_Lbeta
    :param y: truth container, forwarded to calc_LV_Lbeta
    :return: (loss, components) where components is the tuple returned by
        calc_LV_Lbeta
    """
    _, S = pred.shape
    clust_space_dim = 3

    # Column `clust_space_dim` holds the raw beta logit; squash to (0, 1).
    betas = torch.sigmoid(torch.reshape(pred[:, clust_space_dim], [-1, 1]))
    input_coords = batch.ndata["h"][:, 0:clust_space_dim]
    cluster_coords = pred[:, 0:clust_space_dim]
    if clust_space_norm == "twonorm":
        cluster_coords = torch.nn.functional.normalize(cluster_coords, dim=1)
    elif clust_space_norm == "tanh":
        cluster_coords = torch.tanh(cluster_coords)
    elif clust_space_norm == "none":
        pass
    else:
        raise NotImplementedError

    dev = batch.device
    truth_cluster_index = batch.ndata["particle_number"]

    # per-hit event id, derived from the batched graph's node counts
    n_graphs = len(batch.batch_num_nodes())
    event_of_hit = torch.repeat_interleave(
        torch.arange(0, n_graphs).to(dev), batch.batch_num_nodes()
    ).to(dev)

    a = calc_LV_Lbeta(
        input_coords,
        batch,
        y,
        0,       # distance_threshold
        pred_2,  # energy_correction
        beta=betas.view(-1),
        cluster_space_coords=cluster_coords,
        cluster_index_per_event=truth_cluster_index.view(-1).long(),
        batch=event_of_hit.long(),
        qmin=q_min,
        use_average_cc_pos=use_average_cc_pos,
    )

    # total loss = L_V + L_beta
    loss = a[0] + a[1]
    return loss, a
377
+
378
def formatted_loss_components_string(components: dict) -> str:
    """Render the loss components from calc_LV_Lbeta as a multi-line string.

    Each component is shown with its value and its percentage of the total
    (L_V + L_beta). Optional components are appended when present.
    """
    total_loss = components["L_V"] + components["L_beta"]
    fractions = {k: v / total_loss for k, v in components.items()}

    def fkey(key):
        return f"{components[key]:+.4f} ({100.*fractions[key]:.1f}%)"

    formatted = {k: fkey(k) for k in components}
    s = (
        " L_V                 = {L_V}"
        "\n   L_V_attractive    = {L_V_attractive}"
        "\n   L_V_repulsive     = {L_V_repulsive}"
        "\n L_beta              = {L_beta}"
        "\n   L_beta_noise      = {L_beta_noise}"
        "\n   L_beta_sig        = {L_beta_sig}".format(
            L=total_loss, **formatted
        )
    )
    if "L_beta_norms_term" in components:
        s += (
            "\n     L_beta_norms_term = {L_beta_norms_term}"
            "\n     L_beta_logbeta_term = {L_beta_logbeta_term}".format(
                **formatted
            )
        )
    if "L_noise_filter" in components:
        s += f'\n L_noise_filter = {fkey("L_noise_filter")}'
    return s
405
+
406
+
407
def huber(d, delta):
    """
    See: https://en.wikipedia.org/wiki/Huber_loss#Definition
    Multiplied by 2 w.r.t Wikipedia version (aligning with Jan's definition)
    """
    quadratic = d**2
    linear = 2.0 * delta * (torch.abs(d) - delta)
    return torch.where(torch.abs(d) <= delta, quadratic, linear)
415
+
416
+
417
def batch_cluster_indices(
    cluster_id: torch.Tensor, batch: torch.Tensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Turns cluster indices per event to an index in the whole batch
    Example:
    cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
    batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    -->
    offset = torch.LongTensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 5, 5, 5])
    output = torch.LongTensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6])
    """
    device = cluster_id.device
    assert cluster_id.device == batch.device
    # number of clusters per event = max cluster id per event + 1
    n_clusters_per_event = scatter_max(cluster_id, batch, dim=-1)[0] + 1
    # per-event offsets: cumulative cluster counts, with a leading zero
    cumulative = n_clusters_per_event[:-1].cumsum(dim=-1)
    offset_values = torch.cat((torch.zeros(1, device=device), cumulative))
    # broadcast each event's offset onto its hits
    per_hit_offset = torch.gather(offset_values, 0, batch).long()
    return per_hit_offset + cluster_id, n_clusters_per_event
440
+
441
+
442
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=1.0):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.
    """
    n_points = betas.size(0)
    # condensation points: beta above threshold, visited in decreasing beta
    is_seed = betas > tbeta
    seeds = is_seed.nonzero()
    seeds = seeds[(-betas[is_seed]).argsort()]
    # greedily assign each point to the first (highest-beta) seed within td;
    # points never claimed stay background (-1)
    remaining = torch.arange(n_points)
    cluster_of = -1 * torch.ones(n_points, dtype=torch.long).to(betas.device)
    for seed in seeds:
        dist = torch.norm(X[remaining] - X[seed][0], dim=-1)
        close = dist < td
        cluster_of[remaining[close]] = seed[0]
        remaining = remaining[~close]
    return cluster_of
466
+
467
+
468
def scatter_count(input: torch.Tensor):
    """
    Returns ordered counts over an index array
    Example:
    >>> scatter_count(torch.Tensor([0, 0, 0, 1, 1, 2, 2]))  # input
    >>> [3, 2, 2]
    Index assumptions work like in torch_scatter, so:
    >>> scatter_count(torch.Tensor([1, 1, 1, 2, 2, 4, 4]))
    >>> tensor([0, 3, 2, 0, 2])
    """
    # summing a 1 per entry, grouped by index, yields the group sizes
    ones = torch.ones_like(input, dtype=torch.long)
    return scatter_add(ones, input.long())
479
+
480
+
481
def scatter_counts_to_indices(input: torch.LongTensor) -> torch.LongTensor:
    """
    Converts counts to indices. This is the inverse operation of scatter_count
    Example:
        input:  [3, 2, 2]
        output: [0, 0, 0, 1, 1, 2, 2]
    """
    group_ids = torch.arange(input.size(0), device=input.device)
    return torch.repeat_interleave(group_ids, input).long()
491
+
492
+
493
def get_inter_event_norms_mask(
    batch: torch.LongTensor, nclusters_per_event: torch.LongTensor
):
    """
    Creates mask of (nhits x nclusters) that is only 1 if hit i is in the same event as cluster j
    Example:
    cluster_id_per_event = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
    batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    Should return:
    torch.LongTensor([
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1, 1],
        ])
    """
    device = batch.device
    # (nevents x nhits) one-hot event membership of each hit
    event_ids = torch.arange(batch.max() + 1, dtype=torch.long, device=device)
    per_event_rows = (batch == event_ids.unsqueeze(-1)).long()
    # duplicate each event row once per cluster of that event, then transpose
    # to obtain the (nhits x nclusters) mask
    return per_event_rows.repeat_interleave(nclusters_per_event, dim=0).T
530
+
531
+
532
def isin(ar1, ar2):
    """To be replaced by torch.isin for newer releases of torch"""
    matches = ar1[..., None] == ar2
    return matches.any(-1)
535
+
536
+
537
def L_clusters_calc(batch, cluster_space_coords, cluster_index, frac_combinations, q):
    """Sampled pairwise hinge loss over cluster-space coordinates.

    For every event, a random sample (a ``frac_combinations`` fraction of all
    possible pairs) of same-cluster ("positive") and cross-cluster
    ("negative") hit pairs is drawn; their distances feed a q-weighted
    HingeEmbeddingLoss that pulls positives together and pushes negatives
    apart. Returns the pair-count-normalised total as a scalar tensor
    (0.0 when no usable pairs exist).

    NOTE(review): the sampled indices are generated relative to the
    cluster-local coordinate arrays but later index the event-level array
    ``cluster_space_coords_filtered`` — confirm this is intended.
    """
    number_of_pairs = 0
    # Fix: initialise the accumulator ONCE, before the per-event loop.  The
    # previous in-loop initialisation reset the sum for every event and left
    # the name undefined (NameError at return) when every event was skipped.
    L_clusters = torch.tensor(0.0).to(q.device)
    for batch_id in batch.unique():
        bmask = batch == batch_id
        clust_space_filt = cluster_space_coords[bmask]
        pos_pairs_all = []
        neg_pairs_all = []
        # need at least two clusters in the event to form negative pairs
        if len(cluster_index[bmask].unique()) <= 1:
            continue
        for cluster in cluster_index[bmask].unique():
            coords_pos = clust_space_filt[cluster_index[bmask] == cluster]
            coords_neg = clust_space_filt[cluster_index[bmask] != cluster]
            if len(coords_neg) == 0:
                continue
            # sample frac_combinations of the ~n^2/2 possible positive pairs
            total_num = (len(coords_pos) ** 2) / 2
            num = int(frac_combinations * total_num)
            pos_pairs = []
            for _ in range(num):
                pos_pairs.append(
                    [
                        np.random.randint(len(coords_pos)),
                        np.random.randint(len(coords_pos)),
                    ]
                )
            # one negative pair per positive pair
            neg_pairs = []
            for _ in range(len(pos_pairs)):
                neg_pairs.append(
                    [
                        np.random.randint(len(coords_pos)),
                        np.random.randint(len(coords_neg)),
                    ]
                )
            pos_pairs_all += pos_pairs
            neg_pairs_all += neg_pairs
        pos_pairs = torch.tensor(pos_pairs_all)
        neg_pairs = torch.tensor(neg_pairs_all)
        assert pos_pairs.shape == neg_pairs.shape
        if len(pos_pairs) == 0:
            continue
        cluster_space_coords_filtered = cluster_space_coords[bmask]
        qs_filtered = q[bmask]
        pos_norms = (
            cluster_space_coords_filtered[pos_pairs[:, 0]]
            - cluster_space_coords_filtered[pos_pairs[:, 1]]
        ).norm(dim=-1)
        neg_norms = (
            cluster_space_coords_filtered[neg_pairs[:, 0]]
            - cluster_space_coords_filtered[neg_pairs[:, 1]]
        ).norm(dim=-1)
        q_pos = qs_filtered[pos_pairs[:, 0]]
        q_neg = qs_filtered[neg_pairs[:, 0]]
        q_s = torch.cat([q_pos, q_neg])
        norms_pos = torch.cat([pos_norms, neg_norms])
        # targets: +1 for same-cluster pairs, -1 for cross-cluster pairs
        ys = torch.cat([torch.ones_like(pos_norms), -torch.ones_like(neg_norms)])
        L_clusters += torch.sum(
            q_s * torch.nn.HingeEmbeddingLoss(reduce=None)(norms_pos, ys)
        )
        number_of_pairs += norms_pos.shape[0]
    if number_of_pairs > 0:
        L_clusters = L_clusters / number_of_pairs

    return L_clusters
605
+
606
+
607
+
608
+
609
+
src/layers/regression/loss_regression.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+
7
def obtain_PID_charged(dic,pid_true_matched, pids_charged, args, pid_conversion_dict):
    """Build one-hot PID targets for the charged-particle candidates.

    Parameters
    ----------
    dic : dict with "charged_PID_pred" (logits tensor, one row per charged
        candidate) and "charged_idx" (indices into ``pid_true_matched``).
    pid_true_matched : sequence of true PDG ids for the matched showers.
    pids_charged : class list; column ``j`` of the one-hot encodes class
        ``pids_charged[j]``.
    args : unused, kept for interface compatibility.
    pid_conversion_dict : maps a raw PDG id to its class index (unknown ids
        fall back to class 3).

    Returns
    -------
    (predictions, one_hot_targets, valid_mask) — ``valid_mask[i]`` is 0 when
    the converted class is not present in ``pids_charged``.
    """
    predictions = dic["charged_PID_pred"]
    true_pids = np.array(pid_true_matched)[dic["charged_idx"].cpu().tolist()]
    # One row per candidate, one column per known charged class.
    one_hot = torch.zeros(
        len(true_pids), len(pids_charged)
    ).to(predictions.device)
    valid_mask = torch.ones(len(true_pids))
    class_ids = np.array(pids_charged)
    for row, raw_pid in enumerate(true_pids):
        if raw_pid not in pid_conversion_dict:
            print("Unknown PID", raw_pid)
        class_idx = pid_conversion_dict.get(raw_pid, 3)
        columns = np.where(class_ids == class_idx)[0]
        if len(columns) == 0:
            # Class not modelled for charged particles: exclude from the loss.
            valid_mask[row] = 0
        else:
            one_hot[row, columns[0]] = 1
    return predictions, one_hot, valid_mask
26
+
27
+
28
+
29
+
30
+
31
+
32
def obtain_PID_neutral(dic,pid_true_matched,pids_neutral, args, pid_conversion_dict):
    """Build one-hot PID targets for the neutral-particle candidates.

    Parameters
    ----------
    dic : dict with "neutral_PID_pred" (logits tensor) and "neutrals_idx"
        (index tensor into ``pid_true_matched``; may be 0-dim for a single
        candidate).
    pid_true_matched : sequence of true PDG ids for the matched showers.
    pids_neutral : class list; column ``j`` of the one-hot encodes class
        ``pids_neutral[j]``.
    args : unused, kept for interface compatibility.
    pid_conversion_dict : maps a raw PDG id to its class index (unknown ids
        fall back to class 3).

    Returns
    -------
    (predictions, one_hot_targets, valid_mask) — ``valid_mask[i]`` is 0 when
    the converted class is not present in ``pids_neutral``.
    """
    neutral_PID_pred = dic["neutral_PID_pred"]
    neutral_idx = dic["neutrals_idx"]
    # A 0-dim index tensor makes numpy return a scalar; np.atleast_1d restores
    # a 1-D array for ANY dtype (the old code only special-cased np.float64,
    # so a single integer PID would have crashed on len()).
    neutral_PID_true = np.atleast_1d(np.array(pid_true_matched)[neutral_idx.cpu()])
    # One-hot encoded targets, one row per neutral candidate.
    neutral_PID_true_onehot = torch.zeros(
        len(neutral_PID_true), len(pids_neutral)
    ).to(neutral_PID_pred.device)
    mask_neutral = torch.ones(len(neutral_PID_true))

    # convert from true PID to int list PID (4-class encoding)
    pids_neutral_arr = np.array(pids_neutral)
    for i, pid in enumerate(neutral_PID_true):
        if pid not in pid_conversion_dict:
            print("Unknown PID", pid)
        true_idx = pid_conversion_dict.get(pid, 3)
        col = np.where(pids_neutral_arr == true_idx)[0]
        if len(col) == 0:
            # Class not modelled for neutrals: exclude from the loss.
            mask_neutral[i] = 0
        else:
            neutral_PID_true_onehot[i, col[0]] = 1
    neutral_PID_true_onehot = neutral_PID_true_onehot.to(neutral_idx.device)
    return neutral_PID_pred, neutral_PID_true_onehot, mask_neutral
57
+
58
+
59
+
src/layers/shower_dataframe.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataFrame construction and shower-level helpers for particle-flow reconstruction."""
2
+ import torch
3
+ import pandas as pd
4
+ from torch_scatter import scatter_add, scatter_mean, scatter_max
5
+
6
+ from src.layers.clustering import remove_labels_of_double_showers
7
+ from src.layers.shower_matching import obtain_intersection_values
8
+
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Small tensor helpers
12
+ # ---------------------------------------------------------------------------
13
+
14
def nan_like(t):
    """Return a tensor with the shape of ``t`` filled with NaN (float)."""
    nan_filled = torch.zeros_like(t)
    return nan_filled * float("nan")
16
+
17
+
18
def nan_tensor(*size, device):
    """Allocate a NaN-filled float tensor of the given ``size`` on ``device``."""
    blank = torch.zeros(*size, device=device)
    return blank * float("nan")
20
+
21
+
22
+ def _window(tensor, start, count):
23
+ return tensor[start : start + count]
24
+
25
+
26
def _compute_pandora_momentum(labels, g):
    """Scatter-mean the pandora momentum/reference-point node features per cluster.

    Returns (pxyz, ref_pt, pandora_pid, calc_pandora_momentum). All three
    tensor outputs are None when the graph does not carry 'pandora_momentum'.
    """
    if "pandora_momentum" not in g.ndata:
        return None, None, None, False
    momentum = g.ndata["pandora_momentum"]
    reference = g.ndata["pandora_reference_point"]
    # Average each cartesian component over the hits of every cluster label.
    pxyz = torch.stack(
        [scatter_mean(momentum[:, axis], labels) for axis in range(3)], dim=1
    )
    ref_pt = torch.stack(
        [scatter_mean(reference[:, axis], labels) for axis in range(3)], dim=1
    )
    pandora_pid = scatter_mean(g.ndata["pandora_pid"], labels)
    return pxyz, ref_pt, pandora_pid, True
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Per-shower correction
49
+ # ---------------------------------------------------------------------------
50
+
51
def get_correction_per_shower(labels, dic):
    """Pick, for each predicted cluster, the energy correction of its
    highest-beta hit.

    When no noise cluster (label 0) is present, a zero correction is
    prepended so the output stays aligned with the label indexing used
    downstream.
    """
    corrections_all = dic["graph"].ndata["correction"]
    betas_all = dic["graph"].ndata["beta"]
    per_label = []
    for position, label in enumerate(torch.unique(labels)):
        if position == 0:
            if label != 0:
                # No noise cluster: keep slot 0 occupied with a zero correction.
                per_label.append(corrections_all[0].view(-1) * 0)
        in_cluster = labels == label
        best_hit = torch.argmax(betas_all[in_cluster])
        per_label.append(corrections_all[in_cluster][best_hit].view(-1))
    return torch.cat(per_label, dim=0)
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Track–cluster distance helpers
68
+ # ---------------------------------------------------------------------------
69
+
70
def distance_to_true_cluster_of_track(dic, labels):
    """Per predicted cluster, measure how far the cluster's first track hit is
    from the barycenter of the MC cluster that track truly belongs to.

    Returns (distances, number_of_tracks), both indexed by cluster label;
    entries stay 0 for clusters without any track hit.
    NOTE(review): distances are divided by 3300 — presumably a detector-scale
    normalisation constant; confirm against the geometry used elsewhere.
    """
    g = dic["graph"]
    # hit_type == 1 marks track hits (hit_type == 2/3 are used as calo hits
    # elsewhere in this module).
    mask_hit_type_t2 = g.ndata["hit_type"] == 1
    # Reserve one extra slot when there is no noise cluster (label 0) so the
    # tensors can still be indexed directly by label value.
    if torch.sum(labels.unique() == 0) == 0:
        distances = torch.zeros(len(labels.unique()) + 1).float().to(labels.device)
        number_of_tracks = torch.zeros(len(labels.unique()) + 1).int()
    else:
        distances = torch.zeros(len(labels.unique())).float().to(labels.device)
        number_of_tracks = torch.zeros(len(labels.unique())).int()
    for i, label in enumerate(labels.unique()):
        mask_labels_i = labels == label
        # Track hits belonging to this predicted cluster.
        mask = mask_labels_i * mask_hit_type_t2
        if mask.sum() == 0:
            continue
        # Only the first track hit of the cluster is used.
        pos_track = g.ndata["pos_hits_xyz"][mask][0]
        if pos_track.shape[0] == 0:
            continue
        # MC particle the track hit is assigned to, and that particle's hits.
        true_part_idx_track = g.ndata["particle_number"][mask_labels_i * mask_hit_type_t2][0].int()
        mask_labels_i_true = g.ndata["particle_number"] == true_part_idx_track
        mean_pos_cluster_true = torch.mean(
            g.ndata["pos_hits_xyz"][mask_labels_i_true], dim=0
        )
        # Count of track hits attached to the same MC particle.
        number_of_tracks[label] = torch.sum(mask_labels_i_true * mask_hit_type_t2)
        distances[label] = torch.norm(mean_pos_cluster_true - pos_track) / 3300
    return distances, number_of_tracks
95
+
96
+
97
def distance_to_cluster_track(dic, is_track_in_MC):
    """Per MC particle, distance between its track hit and the barycenter of
    its calorimeter hits.

    Returns a tensor aligned with ``is_track_in_MC`` (indexed by particle
    number); particles without tracks keep the ``is_track_in_MC`` value.
    NOTE(review): the variable names are swapped relative to the sibling
    function — here ``mask_hit_type_t1`` selects hit_type == 2 (calo) and
    ``mask_hit_type_t2`` selects hit_type == 1 (tracks). Also the /1000
    normalisation differs from the /3300 used in
    distance_to_true_cluster_of_track — confirm which scale is intended.
    """
    g = dic["graph"]
    mask_hit_type_t1 = g.ndata["hit_type"] == 2
    mask_hit_type_t2 = g.ndata["hit_type"] == 1
    pos_track = g.ndata["pos_hits_xyz"][mask_hit_type_t2]
    particle_track = g.ndata["particle_number"][mask_hit_type_t2]
    if len(particle_track) > 0:
        # Barycenter of the hit_type==2 hits of each track's MC particle
        # (particle 0 = noise gets the origin).
        mean_pos_cluster_all = []
        for i in particle_track:
            if i == 0:
                mean_pos_cluster_all.append(torch.zeros((1, 3)).view(-1, 3).to(particle_track.device))
            else:
                mask_labels_i = g.ndata["particle_number"] == i
                mean_pos_cluster = torch.mean(g.ndata["pos_hits_xyz"][mask_labels_i * mask_hit_type_t1], dim=0)
                mean_pos_cluster_all.append(mean_pos_cluster.view(-1, 3))
        mean_pos_cluster_all = torch.cat(mean_pos_cluster_all, dim=0)
        distance_track_cluster = torch.norm(mean_pos_cluster_all - pos_track, dim=1) / 1000
        # Several tracks may point to the same particle: keep the closest one.
        if len(particle_track) > len(torch.unique(particle_track)):
            distance_track_cluster_unique = []
            for i in torch.unique(particle_track):
                mask_tracks = particle_track == i
                distance_track_cluster_unique.append(torch.min(distance_track_cluster[mask_tracks]).view(-1))
            distance_track_cluster_unique = torch.cat(distance_track_cluster_unique, dim=0)
            unique_particle_track = torch.unique(particle_track)
        else:
            distance_track_cluster_unique = distance_track_cluster
            unique_particle_track = particle_track
        # Scatter the distances into a per-particle tensor; untouched entries
        # keep the is_track_in_MC track counts as placeholder values.
        distance_to_cluster_all = is_track_in_MC.clone().float()
        distance_to_cluster_all[unique_particle_track.long()] = distance_track_cluster_unique
        return distance_to_cluster_all
    else:
        return is_track_in_MC.clone().float()
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # Main DataFrame builder
133
+ # ---------------------------------------------------------------------------
134
+
135
def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None,
    extra_features=None,
    labels_clusters_removed_tracks=None,
):
    """Assemble a per-shower pandas DataFrame comparing truth, reconstruction
    and (optionally) Pandora baselines for one event.

    Rows are the true MC particles followed by the unmatched ("fake")
    predicted showers; matched quantities are scattered in via
    ``row_ind``/``col_ind`` (the assignment from match_showers), and
    unmatched slots stay NaN.

    Returns just the DataFrame when ``number_of_showers_total`` is None,
    otherwise (df, number_of_showers_total, number_of_fake_showers_total)
    with the running counters advanced. Returns ([], 0, 0) when there are no
    matched pairs at all.
    NOTE(review): ``ec_x`` is accepted but never used in this body.
    """
    # --- per-predicted-cluster sums of hit energy / hit-type counts ---
    e_pred_showers = scatter_add(dic["graph"].ndata["e_hits"].view(-1), labels)
    e_pred_showers_ecal = scatter_add(1 * (dic["graph"].ndata["hit_type"].view(-1) == 2), labels)
    e_pred_showers_hcal = scatter_add(1 * (dic["graph"].ndata["hit_type"].view(-1) == 3), labels)
    if not pandora:
        removed_tracks = scatter_add(1 * labels_clusters_removed_tracks, labels)
    if pandora:
        # Pandora calibrated energies are node features: average per cluster.
        e_pred_showers_cali = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        e_pred_showers_pfo = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        pxyz_pred_pfo, ref_pt_pred_pfo, pandora_pid, calc_pandora_momentum = \
            _compute_pandora_momentum(labels, dic["graph"])
    else:
        if e_corr is None:
            # Correction comes from the per-hit model output (highest beta hit).
            corrections_per_shower = get_correction_per_shower(labels, dic)
            e_pred_showers_cali = e_pred_showers * corrections_per_shower
        else:
            # Externally supplied corrections: tail entries belong to fakes.
            corrections_per_shower = e_corr.view(-1)
            if number_of_fakes > 0:
                corrections_per_shower_fakes = corrections_per_shower[-number_of_fakes:]
                corrections_per_shower = corrections_per_shower[:-number_of_fakes]

    # --- per-true-particle aggregates (indexed by particle_number) ---
    e_reco_showers = scatter_add(
        dic["graph"].ndata["e_hits"].view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    e_label_showers = scatter_max(
        labels.view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )[0]
    is_track_in_MC = scatter_add(
        1 * (dic["graph"].ndata["hit_type"].view(-1) == 1),
        dic["graph"].ndata["particle_number"].long(),
    )
    track_chi = scatter_add(
        1 * (dic["graph"].ndata["chi_squared_tracks"].view(-1) == 1),
        dic["graph"].ndata["particle_number"].long(),
    )
    distance_to_cluster_all = distance_to_cluster_track(dic, is_track_in_MC)
    distances, number_of_tracks = distance_to_true_cluster_of_track(dic, labels)

    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()

    # When a noise particle (id 0) exists, row indices were shifted by one in
    # match_showers; undo that shift for truth-aligned indexing.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind

    pred_showers = shower_p_unique
    # --- truth-level quantities, one entry per MC particle ---
    energy_t = (
        dic["part_true"].E_corrected.view(-1).to(e_pred_showers.device)
    ).float()
    gen_status = (
        dic["part_true"].gen_status.view(-1).to(e_pred_showers.device)
    ).float()
    vertex = dic["part_true"].vertex.to(e_pred_showers.device)
    pos_t = dic["part_true"].coord.to(e_pred_showers.device)
    pid_t = dic["part_true"].pid.to(e_pred_showers.device)
    if not pandora:
        labels = remove_labels_of_double_showers(labels, dic["graph"])
    is_track_per_shower = scatter_add(1 * (dic["graph"].ndata["hit_type"] == 1), labels).int()
    is_track = torch.zeros(energy_t.shape).to(e_pred_showers.device)

    # col_ind counts predicted showers excluding cluster 0, hence the +1.
    index_matches = col_ind + 1
    index_matches = index_matches.to(e_pred_showers.device).long()

    # --- NaN-initialised matched-quantity buffers (NaN == unmatched) ---
    dev = e_pred_showers.device
    matched_es = nan_like(energy_t)
    matched_ECAL = nan_like(energy_t)
    matched_HCAL = nan_like(energy_t)
    matched_positions = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_ref_pt = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_pid = nan_like(energy_t).long()
    matched_positions_pfo = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_pandora_pid = nan_tensor(energy_t.shape[0], device=dev)
    matched_ref_pts_pfo = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_extra_features = torch.zeros((energy_t.shape[0], 7)) * torch.nan

    matched_es[row_ind_] = e_pred_showers[index_matches]
    matched_ECAL[row_ind_] = 1.0 * e_pred_showers_ecal[index_matches]
    matched_HCAL[row_ind_] = 1.0 * e_pred_showers_hcal[index_matches]

    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        if calc_pandora_momentum:
            matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
            matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
        is_track[row_ind_] = is_track_per_shower[index_matches].float()
    else:
        if e_corr is None:
            matched_es_cali = matched_es.clone()
            matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
            cluster_removed_tracks = matched_es.clone()
        else:
            # External corrections are a flat buffer across the whole batch:
            # take this event's window starting at number_of_showers_total.
            matched_es_cali = matched_es.clone()
            number_of_showers = e_pred_showers[index_matches].shape[0]
            matched_es_cali[row_ind_] = _window(
                corrections_per_shower, number_of_showers_total, number_of_showers
            )
            cluster_removed_tracks = matched_es.clone()
            cluster_removed_tracks[row_ind_] = 1.0 * removed_tracks[index_matches]

            if pred_pos is not None:
                matched_positions[row_ind_] = _window(pred_pos, number_of_showers_total, number_of_showers)
                matched_ref_pt[row_ind_] = _window(pred_ref_pt, number_of_showers_total, number_of_showers)
                matched_pid[row_ind_] = _window(pred_pid, number_of_showers_total, number_of_showers)
                if not pandora:
                    matched_extra_features[row_ind_] = torch.tensor(
                        _window(extra_features, number_of_showers_total, number_of_showers)
                    )

            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = _window(
                corrections_per_shower, number_of_showers_total, number_of_showers
            )
            number_of_showers_total = number_of_showers_total + number_of_showers
        is_track[row_ind_] = is_track_per_shower[index_matches].float()

    # match the tracks to the particle
    # particle_number_u: same as particle_number but with noise (0) remapped
    # to 100 so scatter_max never picks noise as the track's particle.
    dic["graph"].ndata["particle_number_u"] = dic["graph"].ndata["particle_number"].clone()
    dic["graph"].ndata["particle_number_u"][dic["graph"].ndata["particle_number_u"] == 0] = 100
    tracks_label = scatter_max(
        (dic["graph"].ndata["hit_type"] == 1) * (dic["graph"].ndata["particle_number_u"]), labels
    )[0].int()
    tracks_label = tracks_label - 1
    tracks_label[tracks_label < 0] = 0
    matched_es_tracks = nan_like(energy_t)
    matched_es_tracks_1 = nan_like(energy_t)
    matched_es_tracks[row_ind_] = row_ind_.float()
    matched_es_tracks_1[row_ind_] = tracks_label[index_matches].float()
    # 1 when the track inside the matched cluster belongs to the matched MC
    # particle, 0 otherwise (and forced to 0 for trackless clusters).
    matched_es_tracks_1 = 1.0 * (matched_es_tracks == matched_es_tracks_1)
    matched_es_tracks_1 = matched_es_tracks_1 * is_track

    intersection_E = nan_like(energy_t)
    if len(col_ind) > 0:
        ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
        intersection_E[row_ind_] = ie_e.to(e_pred_showers.device)
        # Mark matched predictions (and the noise cluster) so the remainder
        # are the fake showers.
        pred_showers[index_matches] = -1
        pred_showers[0] = -1
        mask = pred_showers != -1
        fakes_in_event = mask.sum()
        fake_showers_e = e_pred_showers[mask]
        fake_showers_e_hcal = e_pred_showers_hcal[mask]
        fake_showers_e_ecal = e_pred_showers_ecal[mask]
        number_of_fake_showers = mask.sum()

        all_labels = labels.unique().to(e_pred_showers.device)
        number_of_fake_showers = mask.sum()
        fakes_labels = torch.where(mask)[0].to(e_pred_showers.device)
        fake_showers_distance_to_cluster = distances[fakes_labels.cpu()]
        fake_showers_num_tracks = number_of_tracks[fakes_labels.cpu()]

        if e_corr is None or pandora:
            fake_showers_e_cali = e_pred_showers_cali[mask]
        else:
            # Fake quantities live at the tail of the flat batch buffers; take
            # this event's slice of them.
            fakes_positions = pred_pos[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_cali = e_corr[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fakes_pid_pred = pred_pid[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_reco = e_reco_showers[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fakes_positions = fakes_positions.to(e_pred_showers.device)
            fakes_extra_features = extra_features[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_cali = fake_showers_e_cali.to(e_pred_showers.device)
            fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            fake_showers_e_reco = fake_showers_e_reco.to(e_pred_showers.device)

        if pandora:
            fake_pandora_pid = (torch.zeros((fake_showers_e.shape[0], 3)) * torch.nan).to(dev)
            fake_pandora_pid = pandora_pid[mask]
            if calc_pandora_momentum:
                fake_positions_pfo = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
                fake_positions_pfo = pxyz_pred_pfo[mask]
                fakes_positions_ref = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
                fakes_positions_ref = ref_pt_pred_pfo[mask]
        if not pandora:
            if e_corr is None:
                fake_showers_e_cali_factor = corrections_per_shower[mask]
            else:
                fake_showers_e_cali_factor = fake_showers_e_cali
        # Fakes have no truth: NaN placeholders for the truth-level columns.
        fake_showers_showers_e_truw = nan_tensor(fake_showers_e.shape[0], device=dev)
        fake_showers_vertex = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
        fakes_is_track = (torch.zeros((fake_showers_e.shape[0])) * torch.nan).to(dev)
        fakes_is_track = is_track_per_shower[mask]
        fakes_positions_t = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
        if not pandora:
            number_of_fake_showers_total = number_of_fake_showers_total + number_of_fake_showers

        # --- append fake-shower rows after the truth rows ---
        energy_t = torch.cat((energy_t, fake_showers_showers_e_truw), dim=0)
        gen_status = torch.cat((gen_status, fake_showers_showers_e_truw), dim=0)
        vertex = torch.cat((vertex, fake_showers_vertex), dim=0)
        pid_t = torch.cat((pid_t.view(-1), fake_showers_showers_e_truw), dim=0)
        pos_t = torch.cat((pos_t, fakes_positions_t), dim=0)
        e_reco = torch.cat((e_reco_showers[1:], fake_showers_showers_e_truw), dim=0)
        e_labels = torch.cat((e_label_showers[1:], 0 * fake_showers_showers_e_truw), dim=0)
        is_track_in_MC = torch.cat((is_track_in_MC[1:], fake_showers_num_tracks.to(e_reco.device)), dim=0)
        track_chi = torch.cat((track_chi[1:], fake_showers_num_tracks.to(e_reco.device)), dim=0)
        distance_to_cluster_MC = torch.cat(
            (distance_to_cluster_all[1:], fake_showers_distance_to_cluster.to(e_reco.device)), dim=0
        )
        e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
        e_pred_ECAL = torch.cat((matched_ECAL, fake_showers_e_ecal), dim=0)
        e_pred_HCAL = torch.cat((matched_HCAL, fake_showers_e_hcal), dim=0)
        e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0)
        if pred_pos is not None:
            e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0)
            e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0)
            e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_positions), dim=0)
            extra_features_all = torch.cat(
                (matched_extra_features, torch.tensor(fakes_extra_features)), dim=0
            )
        if pandora:
            e_pred_cali_pfo = torch.cat((matched_es_cali_pfo, fake_showers_e_cali), dim=0)
            positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0)
            pandora_pid = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0)
            ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0)
        else:
            cluster_removed_tracks = torch.cat((cluster_removed_tracks, 0 * fake_showers_e_cali), dim=0)
        if not pandora:
            calibration_factor = torch.cat((calibration_per_shower, fake_showers_e_cali_factor), dim=0)

        e_pred_t = torch.cat(
            (intersection_E, nan_like(fake_showers_e)),
            dim=0,
        )
        is_track = torch.cat((is_track, fakes_is_track.to(is_track.device)), dim=0)
        matched_es_tracks_1 = torch.cat(
            (matched_es_tracks_1, 0 * fakes_is_track.to(is_track.device)), dim=0
        )

        # Build shared base dict, then update with pandora- or non-pandora-specific keys
        d = {
            "true_showers_E": energy_t.detach().cpu(),
            "reco_showers_E": e_reco.detach().cpu(),
            "pred_showers_E": e_pred.detach().cpu(),
            "e_pred_and_truth": e_pred_t.detach().cpu(),
            "pid": pid_t.detach().cpu(),
            "step": torch.ones_like(energy_t.detach().cpu()) * step,
            "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
            "is_track_in_cluster": is_track.detach().cpu(),
            "is_track_correct": matched_es_tracks_1.detach().cpu(),
            "is_track_in_MC": is_track_in_MC.detach().cpu(),
            "track_chi": track_chi.detach().cpu(),
            "distance_to_cluster_MC": distance_to_cluster_MC.detach().cpu(),
            "vertex": vertex.detach().cpu().tolist(),
            "ECAL_hits": e_pred_ECAL.detach().cpu(),
            "HCAL_hits": e_pred_HCAL.detach().cpu(),
            "gen_status": gen_status.detach().cpu(),
            "labels": e_labels.detach().cpu(),
        }
        if pandora:
            d.update({
                "pandora_calibrated_E": e_pred_cali.detach().cpu(),
                "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
                "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
                "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
                "pandora_pid": pandora_pid.detach().cpu(),
            })
        else:
            d.update({
                "calibration_factor": calibration_factor.detach().cpu(),
                "calibrated_E": e_pred_cali.detach().cpu(),
                "cluster_removed_tracks": cluster_removed_tracks.detach().cpu(),
            })
            if pred_pos is not None:
                d["pred_pos_matched"] = e_pred_pos.detach().cpu().tolist()
                d["pred_pid_matched"] = e_pred_pid.detach().cpu().tolist()
                d["pred_ref_pt_matched"] = e_pred_ref_pt.detach().cpu().tolist()
                d["matched_extra_features"] = extra_features_all.detach().cpu().tolist()

        d["true_pos"] = pos_t.detach().cpu().tolist()
        df = pd.DataFrame(data=d)
        if number_of_showers_total is None:
            return df
        else:
            return df, number_of_showers_total, number_of_fake_showers_total
    else:
        return [], 0, 0
src/layers/shower_matching.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shower matching utilities for particle-flow reconstruction."""
2
+ import torch
3
+ import numpy as np
4
+ from torch_scatter import scatter_add
5
+ from scipy.optimize import linear_sum_assignment
6
+
7
+
8
class CachedIndexList:
    """List wrapper that memoizes ``list.index`` lookups.

    Useful when the same values are looked up repeatedly: each linear search
    is paid only once per distinct value.
    """

    def __init__(self, lst):
        self.lst = lst
        self.cache = {}

    def index(self, value):
        """Return the position of ``value`` in the wrapped list, caching it."""
        try:
            return self.cache[value]
        except KeyError:
            position = self.lst.index(value)
            self.cache[value] = position
            return position
20
+
21
+
22
def get_labels_pandora(dic, device):
    """Map raw Pandora PFO ids to dense cluster labels 0..K-1.

    Parameters
    ----------
    dic : dict holding the DGL graph under "graph"; node feature
        "pandora_pfo" carries the raw PFO id per hit.
    device : target device for the returned label tensor.

    Returns
    -------
    Long tensor, one label per node, relabelled to consecutive integers in
    ascending order of the raw ids.
    """
    labels_pandora = dic["graph"].ndata["pandora_pfo"].long()
    labels_pandora = labels_pandora + 1
    # torch.unique returns the sorted unique values together with, for each
    # element, its position in that sorted list — exactly the dense
    # relabelling the previous Python-level CachedIndexList loop produced,
    # but vectorized instead of O(n * n_clusters).
    _, cluster_id = torch.unique(labels_pandora.detach().cpu(), return_inverse=True)
    return cluster_id.long().to(device)
30
+
31
+
32
def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Overlap between every predicted cluster and every true particle.

    Returns two (n_pred, n_true) matrices: plain hit counts and
    hit-energy-weighted counts.
    """
    n_pred = len(shower_p_unique)
    n_true = len(particle_ids)
    target_device = shower_p_unique.device
    intersection_matrix = torch.zeros((n_pred, n_true)).to(target_device)
    intersection_matrix_w = torch.zeros((n_pred, n_true)).to(target_device)
    for column, particle in enumerate(particle_ids):
        on_particle = dic["graph"].ndata["particle_number"] == particle
        # Indicator (1 per hit of this particle) and energy-weighted variant.
        hit_counts = torch.zeros_like(labels)
        hit_counts[on_particle] = 1
        weighted = e_hits.clone()
        weighted[~on_particle] = 0
        intersection_matrix[:, column] = scatter_add(hit_counts, labels)
        intersection_matrix_w[:, column] = scatter_add(weighted, labels.to(weighted.device))
    return intersection_matrix, intersection_matrix_w
49
+
50
+
51
def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Size of the union between every predicted cluster and true particle.

    Parameters
    ----------
    shower_p_unique : tensor of predicted cluster labels.
    particle_ids : tensor of true particle ids.
    labels : per-hit predicted cluster label.
    dic : dict holding the graph with per-hit "particle_number".

    Returns
    -------
    (n_pred, n_true) tensor where entry [p, t] counts the hits belonging to
    predicted cluster p OR true particle t.
    """
    len_pred_showers = len(shower_p_unique)
    union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
    for index, id in enumerate(particle_ids):
        mask_p = dic["graph"].ndata["particle_number"] == id
        for index_pred, id_pred in enumerate(shower_p_unique):
            mask_pred_p = labels == id_pred
            # Bool addition acts as logical OR; the sum counts the union.
            mask_union = mask_pred_p + mask_p
            union_matrix[index_pred, index] = torch.sum(mask_union)
    return union_matrix
62
+
63
+
64
def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    """Pick the energy overlap of each matched (true, predicted) shower pair.

    Drops the first predicted row (noise cluster) and, when a noise particle
    (id 0) is present, also the first true column while shifting ``row_ind``
    accordingly. Returns a 1-D tensor, or the integer 0 when there are no
    matches.
    """
    particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(particle_ids == 0) > 0:
        weights_t = torch.transpose(intersection_matrix_w[1:, 1:], 1, 0)
        row_ind = row_ind - 1
    else:
        weights_t = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
    picked = [
        weights_t[row_ind[k], col_ind[k]].view(-1) for k in range(0, len(col_ind))
    ]
    if len(picked) > 0:
        return torch.cat(picked, dim=0)
    return 0
80
+
81
+
82
def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    hdbscan=False,
):
    """Match predicted clusters to true particles by IoU with the Hungarian
    algorithm.

    Pairs with IoU below 0.25 are rejected. Returns
    (shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix) where
    ``row_ind``/``col_ind`` index the accepted (true, predicted) pairs and
    ``i_m_w`` is the energy-weighted intersection matrix.
    NOTE(review): ``local_rank``, ``i``, ``path_save``, ``pandora`` and
    ``hdbscan`` are accepted but unused in this body — presumably kept for a
    shared call signature; confirm with callers.
    """
    iou_threshold = 0.25
    shower_p_unique = torch.unique(labels)
    # Guarantee a slot for the noise cluster (label 0) even when no hit was
    # assigned to it, so downstream "+1" indexing stays valid.
    if torch.sum(labels == 0) == 0:
        shower_p_unique = torch.cat(
            (
                torch.Tensor([0]).to(shower_p_unique.device).view(-1),
                shower_p_unique.view(-1),
            ),
            dim=0,
        )
    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    iou_matrix = i_m / u_m
    # Drop the noise row (and noise column when particle id 0 exists) and
    # transpose so rows are true particles, columns predicted showers.
    if torch.sum(particle_ids == 0) > 0:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, 1:], 1, 0).clone().detach().cpu().numpy()
        )
    else:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
        )
    iou_matrix_num[iou_matrix_num < iou_threshold] = 0
    # Maximise total IoU (linear_sum_assignment minimises, hence the sign flip).
    row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # Keep only assignments with non-zero (i.e. above-threshold) IoU.
    mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0
    row_ind = row_ind[mask_matching_matrix]
    col_ind = col_ind[mask_matching_matrix]
    # Undo the column drop: shift row indices back to particle numbering.
    if torch.sum(particle_ids == 0) > 0:
        row_ind = row_ind + 1
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix
src/layers/tools_for_regression.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch_scatter import scatter_mean, scatter_sum
4
+
5
def pick_lowest_chi_squared(pxpypz, chi_s, batch_idx, xyz_nodes):
    """For each graph in the batch, keep the track with the smallest chi^2.

    Returns (momenta, positions): the (N_graphs, 3) momentum of the selected
    track per graph and the matching (N_graphs, 3) track positions.
    """
    selected_p = []
    selected_xyz = []
    for graph_id in torch.unique(batch_idx):
        in_graph = batch_idx == graph_id
        if torch.sum(in_graph) > 1:
            # Several track candidates: take the best-fit (lowest chi^2) one.
            best = torch.argmin(chi_s[in_graph])
            selected_p.append(pxpypz[in_graph][best].view(-1, 3))
            selected_xyz.append(xyz_nodes[in_graph][best].view(-1, 3))
        else:
            selected_p.append(pxpypz[in_graph].view(-1, 3))
            selected_xyz.append(xyz_nodes[in_graph].view(-1, 3))
    return torch.concat(selected_p, dim=0), torch.stack(selected_xyz)[:, 0]
21
+
22
+
23
+
24
class AverageHitsP(torch.nn.Module):
    # Same layout of the module as the GNN one, but just computes the average of the hits. Try to compare this + ML clustering with Pandora
    def __init__(self, ecal_only=False):
        super(AverageHitsP, self).__init__()
        # When True, HCAL hits are zeroed out for ECAL-dominated clusters
        # before the energy-weighted average (see predict()).
        self.ecal_only = ecal_only
    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons
        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new:
        :return: (momentum magnitude, unit direction, reference point) per graph
        """
        assert graphs_new is not None
        batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
        batch_idx = []
        batch_bounds = []
        if self.ecal_only:
            mask_ecal_only = []  # whether to consider only ECAL or ECAL+HCAL
        # Expand graph index to one entry per node.
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
            batch_bounds.append(n)
        batch_idx = np.array(batch_idx)
        for i in range(len(np.unique(batch_idx))):
            if self.ecal_only:
                # Columns 5/6 of ndata["h"] hold the ECAL/HCAL indicator —
                # assumed feature layout; TODO confirm against graph builder.
                n_ecal_hits = (graphs_new.ndata["h"][batch_idx == i, 5] > 0).sum()
                n_hcal_hits = (graphs_new.ndata["h"][batch_idx == i, 6] > 0).sum()
                # Broadcast the per-graph ECAL fraction to every node of it.
                for _ in range(batch_num_nodes[i]):
                    mask_ecal_only.append((n_ecal_hits / (n_hcal_hits + n_ecal_hits)).item())
        batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
        if self.ecal_only:
            mask_ecal_only = torch.tensor(mask_ecal_only)  # round().int().bool().to(graphs_new.device)
            # Graphs with more than 5% ECAL hits are treated as ECAL-dominated.
            mask_ecal_only = (mask_ecal_only > 0.05).int().bool().to(graphs_new.device)
            #mask_ecal_only=torch.zeros(len(mask_ecal_only)).bool().to(graphs_new.device)
        # Columns 0-2: hit xyz, column 8: hit energy (assumed layout).
        xyz_hits = graphs_new.ndata["h"][:, :3]
        E_hits = graphs_new.ndata["h"][:, 8]
        if self.ecal_only:
            hcal_hits = graphs_new.ndata["h"][:, 6] > 0
            # Zero the HCAL hits of ECAL-dominated graphs so they do not
            # contribute to the energy-weighted direction.
            E_hits[mask_ecal_only & (hcal_hits)] = 0
        # Energy-weighted barycenter per graph.
        weighted_avg_hits = scatter_sum(xyz_hits * E_hits.unsqueeze(1), batch_idx, dim=0)
        E_total = scatter_sum(E_hits, batch_idx, dim=0)
        p_direction = weighted_avg_hits / E_total.unsqueeze(1)
        p_tracks = torch.norm(p_direction, dim=1)
        p_direction = p_direction / torch.norm(p_direction, dim=1).unsqueeze(1)
        # if self.pos_regression:
        # The * 3300 rescaling presumably converts back to detector units —
        # TODO confirm against the coordinate normalisation used upstream.
        return p_tracks, p_direction, weighted_avg_hits / E_total.unsqueeze(1) * 3300  # Reference point
        # return p_tracks
70
+
71
+
72
+
73
class PickPAtDCA(torch.nn.Module):
    # Same layout of the module as the GNN one, but just picks the track
    def __init__(self):
        super(PickPAtDCA, self).__init__()

    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons
        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new:
        :return: (momentum magnitude, momentum direction, barycenter - track position)
        """
        assert graphs_new is not None
        batch_num_nodes = graphs_new.batch_num_nodes()
        # Per-node graph index (removed the unused batch_bounds/filt_hits
        # locals the original carried along).
        batch_idx = []
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
        batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Columns 3-6 of ndata["h"] one-hot encode the hit type (assumed
        # layout); type 1 marks tracks.
        ht = graphs_new.ndata["h"][:, 3:7].argmax(dim=1)
        filt = ht == 1  # track

        # Best (lowest chi^2) track per graph: its momentum at the vertex and
        # its position.
        p_direction, p_xyz = pick_lowest_chi_squared(
            graphs_new.ndata["pos_pxpypz_at_vertex"][filt],
            graphs_new.ndata["chi_squared_tracks"][filt],
            batch_idx[filt],
            graphs_new.ndata["h"][filt, :3]
        )
        # Barycenters of clusters of hits
        xyz_hits = graphs_new.ndata["h"][:, :3]
        E_hits = graphs_new.ndata["h"][:, 8]
        weighted_avg_hits = scatter_sum(xyz_hits * E_hits.unsqueeze(1), batch_idx, dim=0)
        E_total = scatter_sum(E_hits, batch_idx, dim=0)
        barycenters = weighted_avg_hits / E_total.unsqueeze(1)
        p_tracks = torch.norm(p_direction, dim=1)
        return p_tracks, p_direction, barycenters - p_xyz
112
+
113
+
114
+
115
class ECNetWrapperAvg(torch.nn.Module):
    """Energy-correction wrapper that only returns a momentum direction,
    obtained from the ECAL-preferring energy-weighted hit average."""

    def __init__(self):
        super(ECNetWrapperAvg, self).__init__()
        self.AvgHits = AverageHitsP(ecal_only=True)

    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons
        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new:
        :return: (None, unit momentum direction, None, None)
        """
        _, direction, _ = self.AvgHits.predict(x_global_features, graphs_new)
        unit_direction = direction / torch.norm(direction, dim=1).unsqueeze(1)
        return None, unit_direction.clone(), None, None
src/layers/utils_training.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from lightning.pytorch.callbacks import BaseFinetuning
3
+ import torch
4
+ import dgl
5
+ from src.layers.inference_oc import DPC_custom_CLD
6
+ from src.layers.inference_oc import match_showers
7
+ from src.layers.inference_oc import remove_bad_tracks_from_cluster
8
class FreezeClustering(BaseFinetuning):
    """Lightning finetuning callback that freezes the clustering stage.

    Used when training only the downstream energy-correction / PID heads:
    the input batch norm, the GATr backbone and both clustering heads are
    frozen before training starts and never unfrozen.
    """

    def __init__(self):
        super().__init__()

    def freeze_before_training(self, pl_module):
        # Freeze the whole clustering pipeline: input normalisation,
        # GATr backbone, cluster-coordinate head and beta head.
        for submodule in (
            pl_module.ScaledGooeyBatchNorm2_1,
            pl_module.gatr,
            pl_module.clustering,
            pl_module.beta,
        ):
            self.freeze(submodule)
        print("CLUSTERING HAS BEEN FROOOZEN")

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # Intentionally a no-op: the frozen modules stay frozen forever.
        print("Not finetunning")
24
+
25
+
26
+
27
def obtain_batch_numbers(x, g):
    """Return a per-node tensor assigning each node its graph index in `g`.

    :param x: any tensor on the target device; only used to pick the device
        of the output.
    :param g: batched DGL graph.
    :return: 1-D float tensor of length ``g.number_of_nodes()`` where entry
        ``k`` is the index of the sub-graph node ``k`` belongs to.
    """
    device = x.device
    per_graph = [
        graph_index * torch.ones(subgraph.number_of_nodes()).to(device)
        for graph_index, subgraph in enumerate(dgl.unbatch(g))
    ]
    return torch.cat(per_graph, dim=0)
40
+
41
+
42
+
43
def obtain_clustering_for_matched_showers(
    batch_g, model_output, y_all, local_rank, use_gt_clusters=False, add_fakes=True
):
    """Cluster the hits of each event, match clusters to truth showers, and
    build one small DGL graph per matched (and optionally fake) shower.

    :param batch_g: batched DGL graph of hits for all events in the batch.
    :param model_output: per-node output of the clustering model; columns
        0:3 are cluster-space coordinates and column 3 is the beta logit
        (only read when ``use_gt_clusters`` is False).
    :param y_all: truth-particle container with ``batch_number``, ``E``,
        ``m``, ``pid``, ``coord`` and ``copy()``/``mask()`` methods.
    :param local_rank: rank forwarded to ``match_showers`` (logging/plots).
    :param use_gt_clusters: if True, cluster labels are taken from the
        ground-truth ``particle_number`` instead of running DPC clustering.
    :param add_fakes: if True, also build graphs for unmatched (fake)
        predicted showers and append them after the matched ones.
    :return: tuple of (batched shower graphs, true energies, reco energies,
        matched truth PIDs, true daughter-corrected energies, matched truth
        coordinates, number of fakes, fake cluster indices of the LAST
        event processed).
    """

    graphs_showers_matched = []
    graphs_showers_fakes = []
    true_energy_showers = []
    reco_energy_showers = []
    reco_energy_showers_fakes = []
    energy_true_daughters = []
    y_pids_matched = []
    y_coords_matched = []
    if not use_gt_clusters:
        # Stash model predictions on the graph so per-event slices carry them.
        batch_g.ndata["coords"] = model_output[:, 0:3]
        batch_g.ndata["beta"] = model_output[:, 3]
    graphs = dgl.unbatch(batch_g)
    batch_id = y_all.batch_number
    for i in range(0, len(graphs)):
        # Truth particles belonging to event i.
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y = y_all.copy()

        y.mask(mask.flatten())
        dic["part_true"] = y
        if not use_gt_clusters:
            betas = torch.sigmoid(dic["graph"].ndata["beta"])
            X = dic["graph"].ndata["coords"]

        if use_gt_clusters:
            labels = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            # Density-peak style clustering in the learned coordinate space.
            labels = DPC_custom_CLD(X, dic["graph"], model_output.device)

        labels, _ = remove_bad_tracks_from_cluster(dic["graph"], labels)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        shower_p_unique = torch.unique(labels)
        # Hungarian-style matching of predicted clusters to truth showers.
        shower_p_unique, row_ind, col_ind, i_m_w, _ = match_showers(
            labels, dic, particle_ids, model_output, local_rank, i, None
        )
        row_ind = torch.Tensor(row_ind).to(model_output.device).long()
        col_ind = torch.Tensor(col_ind).to(model_output.device).long()
        # particle_number 0 denotes noise; when present, truth indices are
        # shifted by one relative to the matching rows.
        if torch.sum(particle_ids == 0) > 0:
            row_ind_ = row_ind - 1
        else:
            # if there is no zero then index 0 corresponds to particle 1.
            row_ind_ = row_ind
        # Cluster labels start at 1 (0 is the noise cluster).
        index_matches = col_ind + 1
        index_matches = index_matches.to(model_output.device).long()

        for j, unique_showers_label in enumerate(index_matches):
            # Only keep one-to-one matches: skip labels matched more than once.
            if torch.sum(unique_showers_label == index_matches) == 1:
                index_in_matched = torch.argmax(
                    (unique_showers_label == index_matches) * 1
                )
                mask = labels == unique_showers_label
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                # Build an edgeless per-shower graph carrying the hit features.
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)
                g.ndata["h"] = graphs[i].ndata["h"][mask]
                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                energy_t = dic["part_true"].E.to(model_output.device)
                # NOTE(review): `.m` is used as the daughter-corrected energy
                # here — presumably set upstream; confirm against the truth
                # container definition.
                energy_t_corr_daughters = dic["part_true"].m.to(
                    model_output.device
                )
                true_energy_shower = energy_t[row_ind_[j]]
                y_pids_matched.append(y.pid[row_ind_[j]].item())
                y_coords_matched.append(y.coord[row_ind_[j]].detach().cpu().numpy())
                energy_true_daughters.append(energy_t_corr_daughters[row_ind_[j]])
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                graphs_showers_matched.append(g)
                true_energy_showers.append(true_energy_shower.view(-1))
                reco_energy_showers.append(reco_energy_shower.view(-1))
        # Mark matched clusters (and the noise cluster 0) so the remainder
        # are treated as fakes.
        pred_showers = shower_p_unique
        pred_showers[index_matches] = -1
        pred_showers[
            0
        ] = (
            -1
        )
        mask_fakes = pred_showers != -1
        fakes_idx = torch.where(mask_fakes)[0]
        if add_fakes:
            for j in fakes_idx:
                mask = labels == j
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)

                g.ndata["h"] = graphs[i].ndata["h"][mask]

                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                graphs_showers_fakes.append(g)
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                reco_energy_showers_fakes.append(reco_energy_shower.view(-1))
    # NOTE(review): if no shower was matched in the whole batch these
    # torch.cat calls receive empty lists and raise — verify the caller
    # guards against that (see tests/test_energy_correction_no_matches.py).
    graphs_showers_matched = dgl.batch(graphs_showers_matched + graphs_showers_fakes)
    true_energy_showers = torch.cat(true_energy_showers, dim=0)
    reco_energy_showers = torch.cat(reco_energy_showers + reco_energy_showers_fakes, dim=0)
    e_true_corr_daughters = torch.cat(energy_true_daughters, dim=0)
    number_of_fakes = len(reco_energy_showers_fakes)
    return (
        graphs_showers_matched,
        true_energy_showers,
        reco_energy_showers,
        y_pids_matched,
        e_true_corr_daughters,
        y_coords_matched,
        number_of_fakes,
        fakes_idx
    )
src/models/E_correction_module.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import io
4
+ import pickle
5
+
6
class Net(nn.Module):
    """Small fully-connected regression/classification head.

    :param in_features: width of the input feature vector.
    :param out_features: number of outputs; when > 1 and ``return_raw`` is
        False, ``forward`` returns ``(x[:, 0], x[:, 1:])`` as two tensors.
    :param return_raw: if True, always return the raw output tensor.
    """

    def __init__(self, in_features=13, out_features=1, return_raw=True):
        super(Net, self).__init__()
        self.out_features = out_features
        self.return_raw = return_raw
        self.model = nn.ModuleList(
            [
                # nn.BatchNorm1d(13),
                nn.Linear(in_features, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                # nn.BatchNorm1d(64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, out_features),
            ]
        )
        # When True, forward() returns a numpy array (for SHAP explainers).
        self.explainer_mode = False

    def forward(self, x):
        # Accept raw numpy/list inputs (e.g. from SHAP explainers).
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        for layer in self.model:
            x = layer(x)
        if self.out_features > 1 and not self.return_raw:
            return x[:, 0], x[:, 1:]
        if self.explainer_mode:
            # Fix: .numpy() raises on a grad-tracking (or CUDA) tensor;
            # detach and move to CPU first.
            return x.detach().cpu().numpy()
        return x

    def freeze_batchnorm(self):
        """Put the first BatchNorm1d layer (if any) into eval mode."""
        for layer in self.model:
            if isinstance(layer, nn.BatchNorm1d):
                layer.eval()
                print("Frozen batchnorm in 1st layer only - ", layer)
                break
43
+
src/models/Gatr_pf_e_noise.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file includes code adapted from:
3
+
4
+ Geometric Algebra Transformer (GATr)
5
+ https://github.com/Qualcomm-AI-research/geometric-algebra-transformer
6
+
7
+ The original implementation is by Qualcomm AI Research. It has been modified
8
+ and integrated into this project for particle-flow reconstruction at the
9
+ CLD detector (FCC-ee). Please refer to the original repository for
10
+ authorship, documentation, and license information.
11
+ """
12
+ import torch
13
+ import torch.nn as nn
14
+ import dgl
15
+ from src.layers.object_cond import object_condensation_loss2
16
+ from src.models.energy_correction_NN import EnergyCorrection
17
+ from src.layers.inference_oc import create_and_store_graph_output
18
+ import lightning as L
19
+ from torch.optim.lr_scheduler import CosineAnnealingLR
20
+ from xformers.ops.fmha import BlockDiagonalMask
21
+ import os
22
+ import wandb
23
+ from gatr import GATr, SelfAttentionConfig, MLPConfig
24
+ from gatr.interface import embed_point, extract_scalar, extract_point, embed_scalar
25
+ from src.utils.logger_wandb import log_losses_wandb
26
+
27
+
28
class ExampleWrapper(L.LightningModule):
    """GATr-based object-condensation model for particle-flow clustering,
    with an optional downstream energy-correction / PID stage.

    :param args: parsed command-line arguments (``correction``, ``predict``,
        ``qmin``, ``use_gt_clusters``, ``pandora``, learning-rate options...).
    :param dev: device handle stored for later use.
    :param blocks: number of GATr transformer blocks.
    :param hidden_mv_channels: hidden multivector channels of the GATr.
    :param hidden_s_channels: hidden scalar channels of the GATr.
    :param config: optional configuration object, stored as-is.
    """

    def __init__(
        self,
        args,
        dev,
        blocks=10,
        hidden_mv_channels=16,
        hidden_s_channels=64,
        config=None
    ):
        super().__init__()
        self.strict_loading = False
        self.input_dim = 3
        self.output_dim = 4
        # Running sum of training losses and batch count for the epoch log.
        self.loss_final = 0
        self.number_b = 0
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []
        self.args = args
        self.dev = dev
        self.config = config
        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=2,
            out_s_channels=1,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )
        # Input normalisation for the raw hit positions.
        self.ScaledGooeyBatchNorm2_1 = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        # Heads mapping GATr outputs to cluster coordinates and beta.
        self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
        self.beta = nn.Linear(2, 1)
        if self.args.correction:
            self.energy_correction = EnergyCorrection(self)
            self.ec_model_wrapper_charged = self.energy_correction.model_charged
            self.ec_model_wrapper_neutral = self.energy_correction.model_neutral
            self.pids_neutral = self.energy_correction.pids_neutral
            self.pids_charged = self.energy_correction.pids_charged
        else:
            self.pids_neutral = []
            self.pids_charged = []

    def forward(self, g, y, step_count, eval="", return_train=False, use_gt_clusters=False):
        """Run clustering (unless ground-truth clusters are used) and,
        optionally, the energy-correction stage.

        :param g: batched DGL graph of hits.
        :param y: truth-particle container.
        :param step_count: step index (used only by callers for logging).
        :param use_gt_clusters: if True, skip the network and feed dummy
            cluster coordinates downstream.
        :return: correction-stage result when ``args.correction`` is set,
            otherwise ``(x, pred_energy_corr, 0, 0)``.
        """
        if not use_gt_clusters:
            inputs = g.ndata["pos_hits_xyz"].float()
            inputs_scalar = g.ndata["hit_type"].float().view(-1, 1)
            inputs = self.ScaledGooeyBatchNorm2_1(inputs)
            # Embed positions as GA points plus the hit type as a scalar.
            embedded_inputs = embed_point(inputs) + embed_scalar(inputs_scalar)
            embedded_inputs = embedded_inputs.unsqueeze(-2)  # (N, 1, 16)
            # Block-diagonal mask restricts attention to within each event.
            mask = self.build_attention_mask(g)
            scalars = torch.cat((g.ndata["e_hits"].float(), g.ndata["p_hits"].float()), dim=1)
            embedded_outputs, scalar_outputs = self.gatr(
                embedded_inputs, scalars=scalars, attention_mask=mask
            )
            points = extract_point(embedded_outputs[:, 0, :])
            nodewise_outputs = extract_scalar(embedded_outputs)  # (N, 1, 1)
            x_point = points
            x_scalar = torch.cat(
                (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
            )
            x_cluster_coord = self.clustering(x_point)
            beta = self.beta(x_scalar)
            g.ndata["final_cluster"] = x_cluster_coord
            g.ndata["beta"] = beta.view(-1)
            # x: (N, 4) = 3 cluster coordinates + beta logit.
            x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
        else:
            x = torch.ones_like(g.ndata["h"][:, 0:4])

        if self.args.correction:
            result = self.energy_correction.forward_correction(g, x, y, return_train)
            return result
        else:
            # NOTE(review): `beta` is undefined here when use_gt_clusters is
            # True — this branch would raise NameError; presumably never hit
            # with that flag combination. Confirm before relying on it.
            pred_energy_corr = torch.ones_like(beta.view(-1, 1))
            return x, pred_energy_corr, 0, 0

    def build_attention_mask(self, g):
        """Block-diagonal attention mask so hits only attend within their event."""
        batch_numbers = obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )

    def unfreeze_all(self):
        """Re-enable gradients on the correction-stage submodules."""
        for p in self.energy_correction.model_charged.parameters():
            p.requires_grad = True
        for p in self.energy_correction.model_neutral.gatr_pid.parameters():
            p.requires_grad = True
        for p in self.energy_correction.model_neutral.PID_head.parameters():
            p.requires_grad = True

    def training_step(self, batch, batch_idx):
        y = batch[1]
        batch_g = batch[0]
        # Only rank 0 passes the real batch index (used for logging/plots).
        if self.trainer.is_global_zero:
            result = self(batch_g, y, batch_idx)
        else:
            result = self(batch_g, y, 1)

        model_output = result[0]
        e_cor = result[1]
        (loss, losses) = object_condensation_loss2(
            batch_g,
            model_output,
            e_cor,
            y,
            q_min=self.args.qmin,
            use_average_cc_pos=self.args.use_average_cc_pos,
        )
        if self.args.correction:
            self.energy_correction.global_step = self.global_step
            # After epoch 0 the PID losses are treated as "fixed" (see get_loss).
            fixed = self.current_epoch > 0
            loss_EC, loss_pos, loss_neutral_pid, loss_charged_pid = self.energy_correction.get_loss(
                batch_g, y, result, self.stats, fixed
            )
            # NOTE(review): the condensation loss computed above is discarded
            # here — the clustering stage is assumed frozen during correction
            # training (see FreezeClustering).
            loss = loss_EC + loss_neutral_pid + loss_charged_pid

        if self.trainer.is_global_zero:
            log_losses_wandb(True, batch_idx, 0, losses, loss)
        self.loss_final = loss.item() + self.loss_final
        self.number_b = self.number_b + 1
        del model_output
        del e_cor
        del losses
        return loss

    def validation_step(self, batch, batch_idx):
        self.create_paths()
        y = batch[1]
        batch_g = batch[0]
        shap_vals, ec_x = None, None
        if self.args.correction:
            result = self(batch_g, y, 1, use_gt_clusters=self.args.use_gt_clusters)
            model_output = result[0]
            outputs = self.energy_correction.get_validation_step_outputs(batch_g, y, result)
            e_cor1, pred_pos, pred_ref_pt, pred_pid, num_fakes, extra_features, fakes_labels = outputs
            e_cor = e_cor1
        else:
            model_output, e_cor1, loss_ll, _ = self(batch_g, y, 1)
            # Without a correction stage, use unit energy corrections.
            e_cor1 = torch.ones_like(model_output[:, 0].view(-1, 1))
            e_cor = e_cor1
            pred_pos = None
            pred_pid = None
            pred_ref_pt = None
            num_fakes = None
            extra_features = None
            fakes_labels = None

        if self.args.predict:
            if self.args.correction:
                model_output1 = model_output
                e_corr = e_cor
            else:
                model_output1 = torch.cat((model_output, e_cor.view(-1, 1)), dim=1)
                e_corr = None

            # Build per-event shower dataframes (model and optionally Pandora).
            (
                df_batch_pandora,
                df_batch1,
                self.total_number_events,
            ) = create_and_store_graph_output(
                batch_g,
                model_output1,
                y,
                0,
                batch_idx,
                0,
                path_save=self.show_df_eval_path,
                store=True,
                predict=True,
                e_corr=e_corr,
                ec_x=ec_x,
                total_number_events=self.total_number_events,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                use_gt_clusters=self.args.use_gt_clusters,
                number_of_fakes=num_fakes,
                extra_features=extra_features,
                fakes_labels=fakes_labels,
                pandora_available=self.args.pandora,
            )
            self.df_showers_pandora.append(df_batch_pandora)
            self.df_showers_db.append(df_batch1)
        del model_output

    def create_paths(self):
        """Cache the output directory for evaluation dataframes."""
        show_df_eval_path = os.path.join(self.args.model_prefix, "showers_df_evaluation")
        self.show_df_eval_path = show_df_eval_path

    def on_train_epoch_end(self):
        self.log("train_loss_epoch", self.loss_final / self.number_b)

    def on_train_epoch_start(self):
        # Reset the running loss accumulators for the new epoch.
        self.loss_final = 0
        self.number_b = 0
        self.make_mom_zero()
        if self.current_epoch == 0:
            self.stats = {}
            self.stats["counts"] = {}
            self.stats["counts_pid_neutral"] = {}
            self.stats["counts_pid_charged"] = {}

    def on_validation_epoch_start(self):
        self.total_number_events = 0
        self.make_mom_zero()
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []
        self.validation_step_outputs = []

    def make_mom_zero(self):
        """Freeze the input batch-norm statistics after warm-up epochs."""
        if self.current_epoch > 1 or self.args.predict:
            print("making momentum 0")
            self.ScaledGooeyBatchNorm2_1.momentum = 0

    def on_validation_epoch_end(self):
        if self.trainer.is_global_zero:
            if self.args.predict:
                from src.layers.inference_oc import store_at_batch_end
                import pandas as pd

                if self.args.pandora:
                    self.df_showers_pandora = pd.concat(self.df_showers_pandora)
                else:
                    self.df_showers_pandora = []
                self.df_showers_db = pd.concat(self.df_showers_db)
                store_at_batch_end(
                    path_save=os.path.join(
                        self.args.model_prefix, "showers_df_evaluation"
                    ) + "/" + self.args.name_output,
                    df_batch_pandora=self.df_showers_pandora,
                    df_batch1=self.df_showers_db,
                    step=0,
                    predict=True,
                    store=True,
                    pandora_available=self.args.pandora
                )

        self.validation_step_outputs = []
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.start_lr)
        # Cosine anneal for 2 epochs' worth of steps, then a fixed LR.
        scheduler = CosineAnnealingThenFixedScheduler(optimizer, T_max=int(36400 * 2), fixed_lr=1e-5)
        self.scheduler = scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "monitor": "train_loss_epoch",
                "frequency": 1,
            },
        }

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        # Custom scheduler object: step it manually every interval.
        scheduler.step()
+ scheduler.step()
290
+
291
+
292
def obtain_batch_numbers(g):
    """Return per-node graph indices for a batched DGL graph.

    Entry ``k`` of the returned 1-D float tensor is the index of the
    sub-graph that node ``k`` belongs to.
    """
    chunks = []
    for graph_index, subgraph in enumerate(dgl.unbatch(g)):
        chunks.append(graph_index * torch.ones(subgraph.number_of_nodes()))
    return torch.cat(chunks, dim=0)
+ return torch.cat(batch_numbers, dim=0)
300
+
301
+
302
class CosineAnnealingThenFixedScheduler:
    """Cosine-anneal the learning rate for ``T_max`` steps, then hold it
    constant at ``fixed_lr``.

    :param optimizer: wrapped optimizer.
    :param T_max: number of steps of cosine annealing (also its eta_min is
        ``fixed_lr`` so the transition is continuous).
    :param fixed_lr: learning rate used forever after ``T_max`` steps.
    """

    def __init__(self, optimizer, T_max, fixed_lr):
        self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=fixed_lr)
        # Bug fix: this was hard-coded to 1e-6, silently ignoring the
        # `fixed_lr` argument (callers pass 1e-5, matching eta_min above).
        self.fixed_lr = fixed_lr
        self.T_max = T_max
        self.step_count = 0
        self.optimizer = optimizer

    def step(self):
        """Advance one step: cosine phase first, then pin the fixed LR."""
        if self.step_count < self.T_max:
            self.cosine_scheduler.step()
        else:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.fixed_lr
        self.step_count += 1

    def get_last_lr(self):
        if self.step_count < self.T_max:
            return self.cosine_scheduler.get_last_lr()
        else:
            return [self.fixed_lr for _ in self.optimizer.param_groups]

    def state_dict(self):
        return {
            "step_count": self.step_count,
            "cosine_scheduler_state": self.cosine_scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.step_count = state_dict["step_count"]
        self.cosine_scheduler.load_state_dict(state_dict["cosine_scheduler_state"])
+ self.cosine_scheduler.load_state_dict(state_dict["cosine_scheduler_state"])
src/models/energy_correction_NN.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PID + energy correction module.
3
+ The model is called after object condensation clustering to correct
4
+ reconstructed energies and predict particle IDs.
5
+ """
6
+ import numpy as np
7
+ import wandb
8
+ import torch
9
+ from torch.nn import CrossEntropyLoss
10
+ from torch_scatter import scatter_add, scatter_mean
11
+ from typing import NamedTuple, Any
12
+
13
+ from src.layers.utils_training import obtain_clustering_for_matched_showers
14
+ from src.utils.post_clustering_features import (
15
+ get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
16
+ )
17
+ from src.utils.pid_conversion import pid_conversion_dict
18
+ from src.layers.regression.loss_regression import obtain_PID_charged, obtain_PID_neutral
19
+ from src.models.energy_correction_charged import ChargedEnergyCorrection
20
+ from src.models.energy_correction_neutral import (
21
+ NeutralEnergyCorrection, criterion_E_cor, correct_mask_neutral,
22
+ )
23
+
24
+
25
class _ClusteringOutput(NamedTuple):
    """Structured return type for clustering_and_global_features.

    Bundles the per-shower graphs, index partitions (charged vs neutral),
    aggregate features and truth quantities consumed by forward_correction.
    """
    graphs: Any  # batched DGL graph (feature-augmented)
    batch_idx: torch.Tensor  # per-node shower index into the batched graph
    high_level_feats: torch.Tensor  # per-shower aggregate features
    charged_idx: torch.Tensor  # shower indices with >= 1 track
    neutral_idx: torch.Tensor  # shower indices with no track
    feats_charged: torch.Tensor  # NaN-zeroed high_level_feats[charged_idx]
    feats_neutral: torch.Tensor  # NaN-zeroed high_level_feats[neutral_idx]
    pred_energy: torch.Tensor  # ones placeholder, filled by forward_correction
    pred_pos: torch.Tensor  # ones placeholder, filled by forward_correction
    pred_pid: torch.Tensor  # ones placeholder, filled by forward_correction
    true: Any  # matched truth-particle container
    true_pid: torch.Tensor  # matched truth PIDs
    true_coords: torch.Tensor  # matched truth coordinates
    sum_e: torch.Tensor  # summed reco hit energy per shower
    e_true_daughters: torch.Tensor  # daughter-corrected true energies
    n_fakes: int  # number of unmatched (fake) showers appended at the end
    extra_features: torch.Tensor  # per-shower diagnostics (see get_extra_features)
    fakes_idx: torch.Tensor  # fake cluster labels from the clustering step
45
+
46
+
47
+ def _zero_nans(t: torch.Tensor) -> torch.Tensor:
48
+ out = t.clone()
49
+ out[out != out] = 0
50
+ return out
51
+
52
+
53
+ def _decode_pid(pred_pid: torch.Tensor, pids: list, logits: torch.Tensor, idx: torch.Tensor) -> None:
54
+ if pids and len(idx):
55
+ labels = np.array(pids)[np.argmax(logits.cpu().detach(), axis=1)]
56
+ pred_pid[idx.flatten()] = torch.tensor(labels).long().to(idx.device)
57
+
58
+
59
class EnergyCorrection:
    """Orchestrates the post-clustering stage: builds per-shower graphs and
    global features, runs charged/neutral energy-correction + PID models,
    and computes their losses.

    Note: plain Python object (not an nn.Module); the trainable submodels
    are exposed as ``model_charged`` / ``model_neutral`` and registered on
    the main model by its __init__.
    """

    def __init__(self, main_model):
        # main_model is the owning LightningModule; used for device/trainer.
        self.args = main_model.args
        self.get_PID_categories()
        self.get_energy_correction()
        self.pid_conversion_dict = pid_conversion_dict
        self.main_model = main_model
        self.global_step = 0

    def get_PID_categories(self):
        # PID class indices handled by each branch (see pid_conversion_dict).
        self.pids_neutral = [2, 3]
        self.pids_charged = [0, 1, 4]

    def get_energy_correction(self):
        # One correction model per branch: tracks present vs absent.
        self.model_charged = ChargedEnergyCorrection(args=self.args)
        self.model_neutral = NeutralEnergyCorrection(args=self.args)

    def clustering_and_global_features(self, g, x, y, add_fakes=True) -> _ClusteringOutput:
        """Cluster hits into showers, then derive per-shower global features
        and the charged/neutral partition.

        :param g: batched DGL graph of hits.
        :param x: per-node clustering output (coords + beta logit).
        :param y: truth-particle container.
        :param add_fakes: forward unmatched predicted showers as fakes.
        """
        (
            graphs_new, true_new, sum_e, true_pid,
            e_true_corr_daughters, true_coords, number_of_fakes, fakes_idx,
        ) = obtain_clustering_for_matched_showers(
            g, x, y, self.main_model.trainer.global_rank,
            use_gt_clusters=self.args.use_gt_clusters,
            add_fakes=add_fakes,
        )

        # Per-node index of the shower each hit belongs to.
        batch_num_nodes = graphs_new.batch_num_nodes()
        batch_idx = []
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
        batch_idx = torch.tensor(batch_idx).to(self.main_model.device)

        # Normalise positions (3300 is a detector-scale constant — TODO
        # confirm its origin) and append shower-summed features per node.
        graphs_new.ndata["h"][:, 0:3] = graphs_new.ndata["h"][:, 0:3] / 3300
        graphs_sum_features = scatter_add(graphs_new.ndata["h"], batch_idx, dim=0)
        graphs_sum_features = graphs_sum_features[batch_idx]
        betas = torch.sigmoid(graphs_new.ndata["h"][:, -1])
        graphs_new.ndata["h"] = torch.cat(
            (graphs_new.ndata["h"], graphs_sum_features), dim=1
        )

        high_level = get_post_clustering_features(graphs_new, sum_e)
        extra_features = get_extra_features(graphs_new, betas)

        # Placeholders filled later by forward_correction.
        dev = graphs_new.ndata["h"].device
        n = high_level.shape[0]
        pred_energy = torch.ones(n, device=dev)
        pred_pos = torch.ones(n, 3, device=dev)
        pred_pid = torch.ones(n, device=dev).long()

        # Append shower barycenter position, eta and phi to the features.
        node_features_avg = scatter_mean(graphs_new.ndata["h"], batch_idx, dim=0)[:, 0:3]
        eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
        phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
        high_level = torch.cat(
            (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
        )

        # Column 7 holds the number of tracks in the shower (from
        # get_post_clustering_features); it splits charged vs neutral.
        num_tracks = high_level[:, 7]
        charged_idx = torch.where(num_tracks >= 1)[0]
        neutral_idx = torch.where(num_tracks < 1)[0]
        assert len(charged_idx) + len(neutral_idx) == len(num_tracks)
        assert high_level.shape[0] == graphs_new.batch_num_nodes().shape[0]

        return _ClusteringOutput(
            graphs=graphs_new,
            batch_idx=batch_idx,
            high_level_feats=high_level,
            charged_idx=charged_idx,
            neutral_idx=neutral_idx,
            feats_charged=_zero_nans(high_level[charged_idx]),
            feats_neutral=_zero_nans(high_level[neutral_idx]),
            pred_energy=pred_energy,
            pred_pos=pred_pos,
            pred_pid=pred_pid,
            true=true_new,
            true_pid=true_pid,
            true_coords=true_coords,
            sum_e=sum_e,
            e_true_daughters=e_true_corr_daughters,
            n_fakes=number_of_fakes,
            extra_features=extra_features,
            fakes_idx=fakes_idx,
        )

    def forward_correction(self, g, x, y, return_train):
        """Run the charged and neutral branches and assemble predictions.

        :param return_train: selects the (shorter) training return tuple.
        """
        # Fakes are only added at prediction time.
        cf = self.clustering_and_global_features(g, x, y, add_fakes=self.args.predict)

        charged_energies = self.model_charged.charged_prediction(
            cf.graphs, cf.charged_idx, cf.feats_charged
        )
        neutral_energies, neutral_pxyz_avg = self.model_neutral.neutral_prediction(
            cf.graphs, cf.neutral_idx, cf.feats_neutral
        )

        # Branch outputs are 4-tuples when PID is enabled, 3-tuples otherwise.
        if len(self.pids_charged):
            charged_energies, charged_positions, charged_PID_pred, charged_ref_pt_pred = charged_energies
        else:
            charged_energies, charged_positions, _ = charged_energies
        if len(self.pids_neutral):
            neutral_energies, neutral_positions, neutral_PID_pred, neutral_ref_pt_pred = neutral_energies
        else:
            neutral_energies, neutral_positions, _ = neutral_energies

        cf.pred_energy[cf.charged_idx.flatten()] = charged_energies
        cf.pred_energy[cf.neutral_idx.flatten()] = neutral_energies

        _decode_pid(cf.pred_pid, self.pids_charged, charged_PID_pred, cf.charged_idx)
        _decode_pid(cf.pred_pid, self.pids_neutral, neutral_PID_pred, cf.neutral_idx)

        # Clamp unphysical negative energy predictions.
        cf.pred_energy[cf.pred_energy < 0] = 0.0

        pred_ref_pt = torch.ones_like(cf.pred_pos)
        if len(cf.charged_idx):
            pred_ref_pt[cf.charged_idx.flatten()] = charged_ref_pt_pred.to(pred_ref_pt.device)
            cf.pred_pos[cf.charged_idx.flatten()] = charged_positions.float().to(cf.pred_pos.device)
        if len(cf.neutral_idx):
            pred_ref_pt[cf.neutral_idx.flatten()] = neutral_ref_pt_pred.to(cf.neutral_idx.device)
            cf.pred_pos[cf.neutral_idx.flatten()] = neutral_positions.to(cf.neutral_idx.device).float()

        predictions = {
            "pred_energy_corr": cf.pred_energy,
            "pred_pos": cf.pred_pos,
            "neutrals_idx": cf.neutral_idx.flatten(),
            "charged_idx": cf.charged_idx.flatten(),
            "pred_ref_pt": pred_ref_pt,
            "extra_features": cf.extra_features,
            "fakes_labels": cf.fakes_idx,
        }
        if len(self.pids_charged) or len(self.pids_neutral):
            predictions["pred_PID"] = cf.pred_pid
            predictions["charged_PID_pred"] = charged_PID_pred
            predictions["neutral_PID_pred"] = neutral_PID_pred

        if return_train:
            return x, predictions, cf.true, cf.sum_e, cf.true_pid, cf.true, cf.true_coords, cf.n_fakes
        else:
            return (
                x, predictions, cf.true, cf.sum_e, cf.graphs, cf.batch_idx,
                cf.high_level_feats, cf.true_pid, cf.e_true_daughters,
                cf.true_coords, cf.n_fakes,
            )

    def get_loss(self, batch_g, y, result, stats, fixed):
        """Compute the correction-stage losses from a forward_correction
        result (the ``return_train=False`` tuple shape).

        :return: (neutral energy loss, 0 placeholder for position loss,
            neutral PID loss, charged PID loss).
        """
        (
            model_output, dic_e_cor, e_true, e_sum_hits, new_graphs, batch_id,
            graph_level_features, pid_true_matched, e_true_corr_daughters,
            part_coords_matched, num_fakes,
        ) = result

        e_cor = dic_e_cor["pred_energy_corr"]
        mask_neutral_for_loss = correct_mask_neutral(
            torch.tensor(pid_true_matched), dic_e_cor["neutrals_idx"]
        )

        e_true_neutrals = e_true[mask_neutral_for_loss]
        e_pred_neutrals = e_cor[mask_neutral_for_loss]
        e_reco_neutrals = e_sum_hits[mask_neutral_for_loss]
        # Only train on showers whose reco energy is within 60% of truth,
        # to exclude badly-clustered outliers from the regression loss.
        in_distribution = (torch.abs(e_true_neutrals - e_reco_neutrals) / e_true_neutrals) < 0.6
        ypred = e_pred_neutrals[in_distribution]
        ybatch = e_true_neutrals[in_distribution]

        loss_EC_neutrals = criterion_E_cor(ypred.flatten(), ybatch.flatten()) if len(ypred) > 0 else 0
        wandb.log({"loss_EC_neutrals": loss_EC_neutrals})

        loss_neutral_pid = 0
        loss_charged_pid = 0

        if len(self.pids_charged):
            charged_PID_pred, charged_PID_true_onehot, mask_charged = obtain_PID_charged(
                dic_e_cor, pid_true_matched, self.pids_charged, self.args, self.pid_conversion_dict
            )
            loss_charged_pid, acc_charged = pid_loss(
                charged_PID_pred, charged_PID_true_onehot,
                e_true[dic_e_cor["charged_idx"]], mask_charged, fixed, "charged",
            )
            wandb.log({"loss_charged_pid": loss_charged_pid})

        if len(self.pids_neutral):
            neutral_PID_pred, neutral_PID_true_onehot, mask_neutral = obtain_PID_neutral(
                dic_e_cor, pid_true_matched, self.pids_neutral, self.args, self.pid_conversion_dict
            )
            loss_neutral_pid, acc_neutral = pid_loss(
                neutral_PID_pred, neutral_PID_true_onehot,
                e_true, mask_neutral, fixed, "neutral",
            )
            wandb.log({"loss_neutral_pid": loss_neutral_pid})

        return loss_EC_neutrals, 0, loss_neutral_pid, loss_charged_pid

    def get_validation_step_outputs(self, batch_g, y, result):
        """Unpack a validation forward_correction result and assemble the
        per-shower PID-logit matrix appended to the extra features.

        :return: (corrected energies, positions, reference points, PIDs,
            number of fakes, extra features incl. PID logits, fake labels).
        """
        (
            model_output, e_cor, e_true, e_sum_hits,
            new_graphs, batch_id, graph_level_features,
            pid_true_matched, e_true_corr_daughters,
            coords_true, num_fakes,
        ) = result

        if len(self.pids_charged):
            charged_idx = e_cor["charged_idx"]
        if len(self.pids_neutral):
            neutral_idx = e_cor["neutrals_idx"]
        pred_pid = e_cor["pred_PID"]
        charged_PID_pred = e_cor["charged_PID_pred"]
        neutral_PID_pred = e_cor["neutral_PID_pred"]
        pred_pos = e_cor["pred_pos"]
        pred_ref_pt = e_cor["pred_ref_pt"]
        extra_features = e_cor["extra_features"]
        fakes_labels = e_cor["fakes_labels"]
        e_cor = e_cor["pred_energy_corr"]

        # Scatter the branch logits into a common (n_showers, 5) matrix;
        # columns follow pids_charged=[0,1,4] and pids_neutral=[2,3].
        PID_logits = torch.zeros(len(e_cor), len(self.pids_charged) + len(self.pids_neutral)).float()
        PID_logits[charged_idx.cpu(), 0] = charged_PID_pred.detach().cpu()[:, 0]
        PID_logits[charged_idx.cpu(), 1] = charged_PID_pred.detach().cpu()[:, 1]
        PID_logits[charged_idx.cpu(), 4] = charged_PID_pred.detach().cpu()[:, 2]
        PID_logits[neutral_idx.cpu(), 2] = neutral_PID_pred.detach().cpu()[:, 0]
        PID_logits[neutral_idx.cpu(), 3] = neutral_PID_pred.detach().cpu()[:, 1]

        extra_features = extra_features.detach().cpu()
        extra_features = torch.cat((extra_features, PID_logits), dim=1).numpy()

        return e_cor, pred_pos, pred_ref_pt, pred_pid, num_fakes, extra_features, fakes_labels
280
+
281
+
282
def pid_loss(
    pid_pred_all: torch.Tensor,
    pid_true_all: torch.Tensor,
    e_true: torch.Tensor,
    mask: torch.Tensor,
    frozen: bool = False,
    name: str = "",
) -> tuple:
    """Cross-entropy PID loss and accuracy over the masked subset.

    :param pid_pred_all: (N, C) raw logits.
    :param pid_true_all: (N, C) one-hot (or soft) class targets.
    :param e_true: unused here; kept for interface compatibility.
    :param mask: (N,) selection of rows that enter the loss.
    :param frozen: unused here; kept for interface compatibility.
    :param name: unused here; kept for interface compatibility.
    :return: ``(loss, accuracy)``, or ``(0, 0)`` when nothing is scored.
    """
    if not len(pid_pred_all):
        return 0, 0
    mask = mask.bool()
    pid_pred = pid_pred_all[mask]
    pid_true = pid_true_all[mask]
    if not len(pid_pred):
        return 0, 0
    # Bug fix: accuracy was computed as element-wise equality between raw
    # logits and one-hot targets (essentially always ~0); compare the
    # predicted and true class indices instead.
    pred_classes = pid_pred.argmax(dim=1)
    true_classes = pid_true.argmax(dim=1)
    acc = (pred_classes == true_classes).float().mean()
    # CrossEntropyLoss accepts class-probability targets (torch >= 1.10).
    loss = CrossEntropyLoss()(pid_pred, pid_true)
    return loss, acc
src/models/energy_correction_charged.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ energy_correction_charged.py
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch_scatter import scatter_sum
8
+ from xformers.ops.fmha import BlockDiagonalMask
9
+ import dgl
10
+
11
+ from gatr import GATr, SelfAttentionConfig, MLPConfig
12
+ from gatr.interface import embed_point, embed_scalar
13
+ from src.layers.tools_for_regression import PickPAtDCA
14
+
15
+
16
class ChargedEnergyCorrection(nn.Module):
    """Per-cluster correction head for charged particle candidates.

    A small GATr runs over the hits of each candidate cluster; its pooled
    output, concatenated with graph-level features, feeds an MLP PID
    classifier. Energy and direction are not regressed here: they are read
    off the matched track via ``PickPAtDCA``.
    """

    def __init__(self, args):
        super().__init__()
        # Width of the graph-level feature vector and of the pooled GATr
        # multivector output (16 = one 16-component multivector channel).
        self.in_features_global = 16
        self.in_features_gnn = 16  # GATr multivector output dim per batch
        # Columns of the caller's combined PID-logit tensor that this head
        # fills (the caller writes the 3 outputs into columns 0, 1 and 4).
        self.pid_channels = [0, 1, 4]
        n_layers = 3
        self.args = args

        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=4,
            in_s_channels=2,
            out_s_channels=None,
            hidden_s_channels=4,
            num_blocks=3,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )

        out_features_gnn = self.in_features_gnn
        in_features_global = self.in_features_global
        n_pid_classes = len(self.pid_channels)

        # PID MLP head; the "+ 1" accounts for the recovered-energy feature
        # appended in predict().
        # NOTE(review): the first two Linear layers are stacked without an
        # activation in between — confirm this is intentional.
        pid_layers = [nn.Linear(out_features_gnn + in_features_global + 1, 64)]
        for _ in range(n_layers - 1):
            pid_layers.append(nn.Linear(64, 64))
            pid_layers.append(nn.ReLU())
        pid_layers.append(nn.Linear(64, n_pid_classes))
        self.PID_head = nn.Sequential(*pid_layers)

        self.PickPAtDCA = PickPAtDCA()

    def charged_prediction(self, graphs_new, charged_idx, graphs_high_level_features):
        # Re-batch only the clusters flagged as charged and run the head on them.
        unbatched = dgl.unbatch(graphs_new)
        if len(charged_idx) > 0:
            charged_graphs = dgl.batch([unbatched[i] for i in charged_idx])
            charged_energies = self.predict(
                graphs_high_level_features,
                charged_graphs,
            )
        else:
            # No charged candidates: four empty tensors on the right device,
            # mirroring the (E, direction, pid, ref_pt) tuple of predict().
            empty = torch.tensor([]).to(graphs_new.ndata["h"].device)
            charged_energies = [empty, empty, empty, empty]
        return charged_energies

    def predict(self, x_global_features, graphs_new=None):
        """
        Forward pass for charged energy correction.
        :param x_global_features: Global graph-level features (batch, in_features_global)
        :param graphs_new: Batched DGL graph of hit-level data
        :return: (E, direction, pid_pred, ref_pt_pred)
        """
        if graphs_new is not None:
            # Per-node cluster index, used to pool node outputs per cluster.
            batch_num_nodes = graphs_new.batch_num_nodes()
            batch_idx = []
            for i, n in enumerate(batch_num_nodes):
                batch_idx.extend([i] * n)
            batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Node feature layout (inferred from usage here and in the neutral
        # head — TODO confirm upstream): [0:3] hit position, [4:8] one-hot
        # hit type, [8] hit energy, [9] track momentum.
        hits_points = graphs_new.ndata["h"][:, 0:3]
        hit_type = graphs_new.ndata["h"][:, 4:8].argmax(dim=1)
        p = graphs_new.ndata["h"][:, 9]
        e = graphs_new.ndata["h"][:, 8]

        # Geometric-algebra embedding: positions as points, hit type as a scalar;
        # momentum and energy travel through the auxiliary scalar channels.
        embedded_inputs = embed_point(hits_points) + embed_scalar(hit_type.view(-1, 1))
        extra_scalars = torch.cat([p.unsqueeze(1), e.unsqueeze(1)], dim=1)
        mask = self.build_attention_mask(graphs_new)
        embedded_inputs = embedded_inputs.unsqueeze(-2)

        embedded_outputs, _ = self.gatr(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        # Sum-pool the single multivector channel over each cluster.
        embedded_outputs_per_batch = scatter_sum(embedded_outputs[:, 0, :], batch_idx, dim=0)

        # Extra global feature: ratio of columns 6 and 3 — presumably an
        # energy recovered from per-graph sums; TODO confirm meaning and that
        # column 3 cannot be zero here.
        recovered_E = x_global_features[:, 6] / x_global_features[:, 3]
        x_global_features = torch.cat((x_global_features, recovered_E.view(-1, 1)), dim=1)
        model_x = torch.cat([x_global_features, embedded_outputs_per_batch], dim=1)

        pid_pred = self.PID_head(model_x)
        # Track kinematics at the distance of closest approach; energy is the
        # momentum magnitude, direction the unit vector.
        p_tracks, pos, ref_pt_pred = self.PickPAtDCA.predict(x_global_features, graphs_new)
        E = torch.norm(pos, dim=1)
        pos = (pos / torch.norm(pos, dim=1).unsqueeze(1)).clone()
        return E, pos, pid_pred, ref_pt_pred

    @staticmethod
    def obtain_batch_numbers(g):
        # Per-node graph index (float tensor) for a batched DGL graph.
        graphs_eval = dgl.unbatch(g)
        batch_numbers = []
        for index, gj in enumerate(graphs_eval):
            num_nodes = gj.number_of_nodes()
            batch_numbers.append(index * torch.ones(num_nodes))
        return torch.cat(batch_numbers, dim=0)

    def build_attention_mask(self, g):
        # Block-diagonal mask so attention never crosses cluster boundaries.
        batch_numbers = self.obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
src/models/energy_correction_neutral.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ energy_correction_neutral.py
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch_scatter import scatter_sum
9
+ from xformers.ops.fmha import BlockDiagonalMask
10
+ import dgl
11
+
12
+ from gatr import GATr, SelfAttentionConfig, MLPConfig
13
+ from gatr.interface import embed_point, embed_scalar
14
+ from src.models.E_correction_module import Net
15
+ from src.layers.tools_for_regression import ECNetWrapperAvg, AverageHitsP
16
+
17
+
18
class NeutralEnergyCorrection(nn.Module):
    """Per-cluster correction head for neutral particle candidates.

    Two independent GATr backbones with identical hyper-parameters run over
    the cluster hits: one feeds the energy regressor (``Net``), the other the
    PID classifier. Direction comes from an (ECAL-only) average of hit
    positions via ``AverageHitsP``.
    """

    def __init__(self, args):
        super().__init__()
        self.in_features_global = 16
        self.in_features_gnn = 16  # GATr multivector output dim per batch
        # Columns of the caller's combined PID-logit tensor that this head
        # fills (the caller writes the 2 outputs into columns 2 and 3).
        self.pid_channels = [2, 3]
        self.args = args
        n_layers = 3

        # Shared hyper-parameters for both backbones (energy and PID).
        gatr_kwargs = dict(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=4,
            in_s_channels=2,
            out_s_channels=None,
            hidden_s_channels=4,
            num_blocks=3,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )
        self.gatr = GATr(**gatr_kwargs)
        self.gatr_pid = GATr(**gatr_kwargs)

        out_features_gnn = self.in_features_gnn
        in_features_global = self.in_features_global
        n_pid_classes = len(self.pid_channels)
        out_f = 1  # Energy prediction (scalar)

        # PID MLP head.
        # NOTE(review): the first two Linear layers are stacked without an
        # activation in between — confirm this is intentional.
        pid_layers = [nn.Linear(out_features_gnn + in_features_global, 64)]
        for _ in range(n_layers - 1):
            pid_layers.append(nn.Linear(64, 64))
            pid_layers.append(nn.ReLU())
        pid_layers.append(nn.Linear(64, n_pid_classes))
        self.PID_head = nn.Sequential(*pid_layers)

        # Energy regressor over pooled GATr output + global features.
        self.model = Net(
            in_features=out_features_gnn + in_features_global,
            out_features=out_f,
            return_raw=True,
        )
        self.ec_model_wrapper_neutral_avg = ECNetWrapperAvg()
        self.AvgHits = AverageHitsP(ecal_only=True)

    def neutral_prediction(self, graphs_new, neutral_idx, features_neutral_no_nan):
        # Re-batch only the clusters flagged as neutral and run the head on them.
        unbatched = dgl.unbatch(graphs_new)
        if len(neutral_idx) > 0:
            neutral_graphs = dgl.batch([unbatched[i] for i in neutral_idx])
            neutral_energies = self.predict(
                features_neutral_no_nan,
                neutral_graphs,
            )
            # Secondary momentum estimate from the hit-averaging wrapper
            # (element [1] of its prediction tuple).
            neutral_pxyz_avg = self.ec_model_wrapper_neutral_avg.predict(
                features_neutral_no_nan,
                neutral_graphs,
            )[1]
        else:
            # No neutral candidates: empty tensors on the right device,
            # mirroring the (E, direction, pid, ref_pt) tuple of predict().
            empty = torch.tensor([]).to(graphs_new.ndata["h"].device)
            neutral_energies = [empty, empty, empty, empty]
            neutral_pxyz_avg = empty
        return neutral_energies, neutral_pxyz_avg

    def predict(self, x_global_features, graphs_new=None):
        """
        Forward pass for neutral energy correction.
        :param x_global_features: Global graph-level features (batch, in_features_global)
        :param graphs_new: Batched DGL graph of hit-level data
        :return: (E_pred, direction, pid_pred, ref_pt_pred)
        """
        if graphs_new is not None:
            # Per-node cluster index, used to pool node outputs per cluster.
            batch_num_nodes = graphs_new.batch_num_nodes()
            batch_idx = []
            for i, n in enumerate(batch_num_nodes):
                batch_idx.extend([i] * n)
            batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Node feature layout (inferred from usage — TODO confirm upstream):
        # [0:3] hit position, [4:8] one-hot hit type, [8] hit energy,
        # [9] track momentum.
        hits_points = graphs_new.ndata["h"][:, 0:3]
        hit_type = graphs_new.ndata["h"][:, 4:8].argmax(dim=1)
        p = graphs_new.ndata["h"][:, 9]
        e = graphs_new.ndata["h"][:, 8]

        # Same geometric-algebra embedding is fed to both backbones.
        embedded_inputs = embed_point(hits_points) + embed_scalar(hit_type.view(-1, 1))
        extra_scalars = torch.cat([p.unsqueeze(1), e.unsqueeze(1)], dim=1)
        mask = self.build_attention_mask(graphs_new)
        embedded_inputs = embedded_inputs.unsqueeze(-2)

        # Energy branch: GATr -> per-cluster sum-pool -> concat with globals.
        embedded_outputs, _ = self.gatr(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        embedded_outputs_per_batch = scatter_sum(embedded_outputs[:, 0, :], batch_idx, dim=0)
        model_x = torch.cat([x_global_features, embedded_outputs_per_batch], dim=1)

        # PID branch: independent backbone, same pooling scheme.
        embedded_outputs_pid, _ = self.gatr_pid(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        embedded_outputs_per_batch_pid = scatter_sum(
            embedded_outputs_pid[:, 0, :], batch_idx, dim=0
        )
        model_x_pid = torch.cat([x_global_features, embedded_outputs_per_batch_pid], dim=1)

        res = self.model(model_x)
        pid_pred = self.PID_head(model_x_pid)
        E_pred = res[:, 0]

        # Direction from averaged (ECAL) hit positions, normalized to a unit vector.
        _, p_pred, ref_pt_pred = self.AvgHits.predict(x_global_features, graphs_new)
        p_pred = (p_pred / torch.norm(p_pred, dim=1).unsqueeze(1)).clone()
        return E_pred, p_pred, pid_pred, ref_pt_pred

    @staticmethod
    def obtain_batch_numbers(g):
        # Per-node graph index (float tensor) for a batched DGL graph.
        graphs_eval = dgl.unbatch(g)
        batch_numbers = []
        for index, gj in enumerate(graphs_eval):
            num_nodes = gj.number_of_nodes()
            batch_numbers.append(index * torch.ones(num_nodes))
        return torch.cat(batch_numbers, dim=0)

    def build_attention_mask(self, g):
        # Block-diagonal mask so attention never crosses cluster boundaries.
        batch_numbers = self.obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
139
+
140
+
141
def correct_mask_neutral(pid_neutral, neural_mask):
    """
    Filter neutral-candidate indices, keeping only genuine neutral PIDs.

    :param pid_neutral: per-candidate PDG codes (sign is ignored)
    :param neural_mask: index tensor of candidates currently flagged neutral
    :return: subset of ``neural_mask`` whose |PDG| is a photon (22),
        K-long (130) or neutron (2112)
    """
    device = neural_mask.device
    abs_pids = torch.abs(pid_neutral.to(device))
    neutral_pdg_codes = torch.tensor([22, 130, 2112], device=abs_pids.device)
    is_neutral = torch.isin(abs_pids[neural_mask], neutral_pdg_codes)
    return neural_mask[is_neutral.to(device)]
151
+
152
+
153
def criterion_E_cor(ypred, ytrue):
    """Mean absolute error for energy correction; 0 for empty predictions."""
    if not len(ypred):
        return 0
    per_item = F.l1_loss(ypred, ytrue, reduction="none")
    return torch.mean(per_item)
src/models/wrapper/example_mode_gatr_noise.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from src.models.Gatr_pf_e_noise import ExampleWrapper
3
+
4
+
5
class GraphTransformerNetWrapper(torch.nn.Module):
    """Thin adapter exposing ``ExampleWrapper`` under the trainer's
    (graph, targets, step_count) forward interface."""

    def __init__(self, args, dev, **kwargs) -> None:
        super().__init__()
        # The actual GATr particle-flow model; all kwargs pass straight through.
        self.mod = ExampleWrapper(args, dev, **kwargs)

    def forward(self, g, y, step_count, **kwargs):
        # Pure delegation — no extra processing in the wrapper.
        return self.mod(g, y, step_count, **kwargs)
12
+
13
+
14
def get_model(data_config, args, dev, **kwargs):
    """Factory: build the wrapped GATr model.

    :return: (model, model_info) — ``model_info`` is currently empty.
    """
    wrapped = GraphTransformerNetWrapper(args, dev, **kwargs)
    return wrapped, {}
18
+
19
+
20
def get_loss(data_config, **kwargs):
    """Return the default training criterion (mean-squared error)."""
    criterion = torch.nn.MSELoss()
    return criterion
src/train_lightning1.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ import glob
6
+ import torch
7
+ import lightning as L
8
+ from lightning.pytorch.loggers import WandbLogger
9
+
10
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
11
+
12
+ from src.utils.parser_args import parser
13
+ from src.utils.train_utils import (
14
+ train_load,
15
+ test_load,
16
+ get_samples_steps_per_epoch,
17
+ model_setup,
18
+ set_gpus,
19
+ )
20
+ from src.utils.load_pretrained_models import (
21
+ load_train_model,
22
+ load_test_model,
23
+ )
24
+ from src.utils.callbacks import (
25
+ get_callbacks,
26
+ get_callbacks_eval,
27
+ )
28
+
29
+
30
+ # ----------------------------------------------------------------------
31
+ # Helpers
32
+ # ----------------------------------------------------------------------
33
+
34
def setup_wandb(args):
    """Create a WandbLogger from CLI options.

    ``log_model="all"`` uploads every checkpoint the trainer saves.
    """
    return WandbLogger(
        project=args.wandb_projectname,
        entity=args.wandb_entity,
        name=args.wandb_displayname,
        log_model="all",
    )
41
+
42
+
43
def build_trainer(args, gpus, logger, training=True):
    """Construct a Lightning Trainer for training or evaluation.

    :param args: parsed CLI namespace (uses correction, model_prefix,
        num_epochs, train_batches, freeze_clustering via callbacks)
    :param gpus: list of GPU indices for ``devices``
    :param logger: a Lightning logger (wandb)
    :param training: True builds the fit trainer, False the eval trainer
    """
    callbacks = get_callbacks(args) if training else get_callbacks_eval(args)

    # "auto" when training only the correction head, DDP for full training.
    # NOTE(review): when training=False and args.correction is False this
    # evaluates to None — confirm the installed Lightning version accepts
    # strategy=None (newer versions expect "auto" instead).
    strategy = "auto" if args.correction else "ddp" if training else None

    return L.Trainer(
        callbacks=callbacks,
        accelerator="gpu",
        devices=gpus,
        default_root_dir=args.model_prefix,
        logger=logger,
        max_epochs=args.num_epochs if training else None,
        strategy=strategy,
        # Validation capped at 5 batches during training for speed.
        limit_train_batches=args.train_batches if training else None,
        limit_val_batches=5 if training else None,
    )
59
+
60
+
61
+ # ----------------------------------------------------------------------
62
+ # Main
63
+ # ----------------------------------------------------------------------
64
+
65
def main():
    """Entry point: parse CLI args, build data loaders and model, then train
    and/or evaluate depending on ``--predict`` / ``--data-test``."""
    args = parser.parse_args()
    # NOTE(review): anomaly detection is a significant slowdown — presumably
    # left on for debugging; consider gating it behind a flag.
    torch.autograd.set_detect_anomaly(True)

    training_mode = not args.predict
    args.local_rank = 0

    # --------------------------------------------------
    # Data
    # --------------------------------------------------
    args = get_samples_steps_per_epoch(args)

    if training_mode:
        # Expand the first training path into its parquet files.
        # NOTE(review): assumes --data-train has at least one entry; an empty
        # list raises IndexError here.
        args.data_train = glob.glob(args.data_train[0] + "*.parquet")
        train_loader, val_loader, data_config, train_input_names = train_load(args)
    else:
        test_loaders, data_config = test_load(args)

    # --------------------------------------------------
    # Model & devices
    # --------------------------------------------------
    model = model_setup(args, data_config)
    gpus, dev = set_gpus(args)

    # Resume/fine-tune from a checkpoint when requested.
    if training_mode and args.load_model_weights:
        model = load_train_model(args, dev)

    # --------------------------------------------------
    # Logger
    # --------------------------------------------------
    wandb_logger = setup_wandb(args)

    # --------------------------------------------------
    # Training
    # --------------------------------------------------
    if training_mode:
        trainer = build_trainer(args, gpus, wandb_logger, training=True)
        args.local_rank = trainer.global_rank

        trainer.fit(
            model=model,
            train_dataloaders=val_loader if False else train_loader,
        ) if False else trainer.fit(
            model=model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )

    # --------------------------------------------------
    # Evaluation
    # --------------------------------------------------
    if args.data_test:
        if args.load_model_weights:
            model = load_test_model(args, dev)

        trainer = build_trainer(args, gpus, wandb_logger, training=False)

        # test_loaders maps dataset name -> lazy loader factory.
        # NOTE(review): evaluation runs through validate(), not test() —
        # presumably so the same validation hooks produce the dataframes.
        for name, get_test_loader in test_loaders.items():
            test_loader = get_test_loader()
            trainer.validate(
                model=model,
                dataloaders=test_loader,
            )
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()
src/utils/callbacks.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from lightning.pytorch.callbacks import (
3
+ TQDMProgressBar,
4
+ ModelCheckpoint,
5
+ LearningRateMonitor,
6
+ )
7
+ from src.layers.utils_training import FreezeClustering
8
+
9
def get_callbacks(args):
    """Callbacks for training runs: periodic checkpoints, LR monitoring, a
    progress bar, and (optionally) clustering-freeze.

    :param args: needs ``model_prefix`` and ``freeze_clustering``.
    """
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.model_prefix,  # checkpoints_path, # <--- specify this on the trainer itself for version control
        filename="_{epoch}_{step}",
        # every_n_epochs=val_every_n_epochs,
        every_n_train_steps=500,  # checkpoint by step, not epoch
        save_top_k=-1,  # <--- this is important! keep every checkpoint
        save_weights_only=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    callbacks = [
        TQDMProgressBar(refresh_rate=10),
        checkpoint_callback,
        lr_monitor,
    ]
    # Optionally freeze the clustering sub-network while training the rest.
    if args.freeze_clustering:
        callbacks.append(FreezeClustering())
    return callbacks
27
+
28
def get_callbacks_eval(args):
    """Callbacks for evaluation runs: only a fast-refresh progress bar."""
    return [TQDMProgressBar(refresh_rate=1)]
src/utils/import_tools.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from importlib.util import spec_from_file_location, module_from_spec
2
+
3
+
4
def import_module(path, name='_mod'):
    """Dynamically load and execute a Python module from a file path.

    :param path: filesystem path to the .py file
    :param name: module name to register on the loaded module object
    :return: the executed module object
    """
    module_spec = spec_from_file_location(name, path)
    loaded = module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
src/utils/inference/pandas_helpers.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import pickle
3
+ import mplhep as hep
4
+ from src.utils.pid_conversion import pid_conversion_dict
5
+
6
+ #hep.style.use("CMS")
7
+ import matplotlib
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+
13
+
14
def open_mlpf_dataframe(path_mlpf, neutrals_only=False, charged_only=False):
    """Load a pickled shower dataframe and add a coarse true-PID class column.

    NOTE(review): ``neutrals_only`` and ``charged_only`` are currently
    unused — either implement the filtering or drop the parameters.

    :param path_mlpf: path to the pickled pandas DataFrame
    :return: the same DataFrame with a ``pid_4_class_true`` column added
    """
    data = pd.read_pickle(path_mlpf)
    sd = data
    # Map raw PDG codes to the project's coarse PID classes.
    sd["pid_4_class_true"] = sd["pid"].map(pid_conversion_dict)
    if "pred_pid_matched" in sd.columns:
        # Values below -1 presumably encode "no matched prediction";
        # convert to NaN so they drop out of metrics. TODO confirm sentinel.
        sd.loc[sd["pred_pid_matched"] < -1, "pred_pid_matched"] = np.nan
    return sd
21
+
22
def concat_with_batch_fix(dfs, batch_key="number_batch"):
    """Concatenate event dataframes while keeping batch numbers unique.

    Each dataframe's batch column is shifted by a running offset so that
    batch indices never collide across the inputs.

    :param dfs: iterable of DataFrames, each with a ``batch_key`` column
    :param batch_key: name of the batch-index column
    :return: one concatenated DataFrame with a fresh RangeIndex
    :raises KeyError: if any dataframe lacks ``batch_key``
    """
    shifted = []
    offset = 0
    for frame in dfs:
        frame = frame.copy()
        if batch_key not in frame.columns:
            raise KeyError(f"'{batch_key}' not found in one of the DataFrames.")
        frame[batch_key] = frame[batch_key] + offset
        offset = frame[batch_key].max() + 1
        shifted.append(frame)
    return pd.concat(shifted, ignore_index=True)
36
+
src/utils/load_pretrained_models.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
def load_train_model(args, dev):
    """Load a pretrained GATr checkpoint to resume/fine-tune training.

    ``strict=False`` tolerates missing/extra keys so heads added after the
    checkpoint was saved are left randomly initialized.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
    model = GravnetModel.load_from_checkpoint(
        args.load_model_weights, args=args, dev=0, map_location=dev, strict=False)
    return model
9
+
10
def load_test_model(args, dev):
    """Load a model for evaluation.

    Two modes:
    - plain clustering model (``not args.correction``): load the checkpoint
      directly;
    - correction mode: load the correction checkpoint non-strictly, then
      graft the clustering submodules (gatr, batch norm, clustering, beta)
      from a separate clustering checkpoint.

    NOTE(review): if ``args.load_model_weights`` is None neither branch
    assigns ``model`` and ``model.eval()`` raises UnboundLocalError —
    confirm callers always pass a checkpoint.
    """
    if args.load_model_weights is not None and (not args.correction):
        from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
        model = GravnetModel.load_from_checkpoint(
            args.load_model_weights, args=args, dev=0, map_location=dev, strict=False
        )

    if args.load_model_weights is not None and args.correction:
        from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
        ckpt = torch.load(args.load_model_weights, map_location=dev)

        # Non-strict load: correction checkpoints may lack clustering weights.
        state_dict = ckpt["state_dict"]
        model = GravnetModel( args=args, dev=0)
        model.load_state_dict(state_dict, strict=False)

        # Overwrite the clustering part with the dedicated clustering checkpoint.
        # NOTE(review): map_location is hard-coded to cuda:0 here, unlike the
        # ``dev`` used above — confirm this is intentional.
        model2 = GravnetModel.load_from_checkpoint(args.load_model_weights_clustering, args=args, dev=0, strict=False, map_location=torch.device("cuda:0"))
        model.gatr = model2.gatr
        model.ScaledGooeyBatchNorm2_1 = model2.ScaledGooeyBatchNorm2_1
        model.clustering = model2.clustering
        model.beta = model2.beta
    model.eval()
    return model
32
+
src/utils/logger_wandb.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ import numpy as np
3
+ import torch
4
+ from sklearn.metrics import roc_curve, roc_auc_score
5
+ import json
6
+ import dgl
7
+ import matplotlib.pyplot as plt
8
+ from sklearn.decomposition import PCA
9
+ from torch_scatter import scatter_max
10
+ from matplotlib.cm import ScalarMappable
11
+ from matplotlib.colors import Normalize
12
+
13
+
14
def log_losses_wandb(
    logwandb, num_batches, local_rank, losses, loss, val=False
):
    """Log loss components to wandb every 10th batch on rank 0.

    :param losses: sequence of loss components; indices 0-3 are the
        condensation terms, 12/13 the attractive/repulsive terms
    :param loss: total regression loss
    :param val: appends " val" to every metric name when True
    """
    suffix = " val" if val else ""
    should_log = logwandb and ((num_batches - 1) % 10) == 0 and local_rank == 0
    if not should_log:
        return
    wandb.log(
        {
            "loss" + suffix + " regression": loss,
            "loss" + suffix + " lv": losses[0],
            "loss" + suffix + " beta": losses[1],
            "loss" + suffix + " beta sig": losses[2],
            "loss" + suffix + " beta noise": losses[3],
            "loss" + suffix + " attractive": losses[12],
            "loss" + suffix + " repulsive": losses[13],
        }
    )
+ )
33
+
src/utils/parser_args.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse

# Command-line interface shared by training and evaluation entry points.
parser = argparse.ArgumentParser()

# ---- model-freezing options -------------------------------------------------
parser.add_argument(
    "--freeze-clustering",
    action="store_true",
    default=False,
    help="Freeze the clustering part of the model",
)


# ---- data inputs ------------------------------------------------------------
parser.add_argument("-c", "--data-config", type=str, help="data config YAML file")

parser.add_argument(
    "-i",
    "--data-train",
    nargs="*",
    default=[],
    help="training files; supported syntax:"
    " (a) plain list, `--data-train /path/to/a/* /path/to/b/*`;"
    " (b) (named) groups [Recommended], `--data-train a:/path/to/a/* b:/path/to/b/*`,"
    " the file splitting (for each dataloader worker) will be performed per group,"
    " and then mixed together, to ensure a uniform mixing from all groups for each worker.",
)
parser.add_argument(
    "-l",
    "--data-val",
    nargs="*",
    default=[],
    help="validation files; when not set, will use training files and split by `--train-val-split`",
)
parser.add_argument(
    "-t",
    "--data-test",
    nargs="*",
    default=[],
    help="testing files; supported syntax:"
    " (a) plain list, `--data-test /path/to/a/* /path/to/b/*`;"
    " (b) keyword-based, `--data-test a:/path/to/a/* b:/path/to/b/*`, will produce output_a, output_b;"
    " (c) split output per N input files, `--data-test a%10:/path/to/a/*`, will split per 10 input files",
)

# ---- data sampling / loading behavior --------------------------------------
parser.add_argument(
    "--data-fraction",
    type=float,
    default=1,
    help="fraction of events to load from each file; for training, the events are randomly selected for each epoch",
)
parser.add_argument(
    "--file-fraction",
    type=float,
    default=1,
    help="fraction of files to load; for training, the files are randomly selected for each epoch",
)
parser.add_argument(
    "--fetch-by-files",
    action="store_true",
    default=False,
    help="When enabled, will load all events from a small number (set by ``--fetch-step``) of files for each data fetching. "
    "Otherwise (default), load a small fraction of events from all files each time, which helps reduce variations in the sample composition.",
)
parser.add_argument(
    "--fetch-step",
    type=float,
    default=0.01,
    help="fraction of events to load each time from every file (when ``--fetch-by-files`` is disabled); "
    "Or: number of files to load each time (when ``--fetch-by-files`` is enabled). Shuffling & sampling is done within these events, so set a large enough value.",
)

parser.add_argument(
    "--train-val-split",
    type=float,
    default=0.8,
    help="training/validation split fraction",
)


# ---- network / checkpointing ------------------------------------------------
parser.add_argument(
    "-n",
    "--network-config",
    type=str,
    help="network architecture configuration file; the path must be relative to the current dir",
)
parser.add_argument(
    "-m",
    "--model-prefix",
    type=str,
    default="models/{auto}/networkss",
    help="path to save or load the model; for training, this will be used as a prefix, so model snapshots "
    "will saved to `{model_prefix}_epoch-%d_state.pt` after each epoch, and the one with the best "
    "validation metric to `{model_prefix}_best_epoch_state.pt`; for testing, this should be the full path "
    "including the suffix, otherwise the one with the best validation metric will be used; "
    "for training, `{auto}` can be used as part of the path to auto-generate a name, "
    "based on the timestamp and network configuration",
)

parser.add_argument(
    "--load-model-weights",
    type=str,
    default=None,
    help="initialize model with pre-trained weights",
)
parser.add_argument(
    "--load-model-weights-clustering",
    type=str,
    default=None,
    help="initialize model with pre-trained weights for clustering part of the model",
)

# ---- optimization schedule --------------------------------------------------
parser.add_argument("--start-lr", type=float, default=5e-3, help="start learning rate")

parser.add_argument("--num-epochs", type=int, default=20, help="number of epochs")
parser.add_argument(
    "--steps-per-epoch",
    type=int,
    default=None,
    help="number of steps (iterations) per epochs; "
    "if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--steps-per-epoch-val",
    type=int,
    default=None,
    help="number of steps (iterations) per epochs for validation; "
    "if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--samples-per-epoch",
    type=int,
    default=None,
    help="number of samples per epochs; "
    "if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--samples-per-epoch-val",
    type=int,
    default=None,
    help="number of samples per epochs for validation; "
    "if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples",
)
parser.add_argument("--batch-size", type=int, default=128, help="batch size")

# ---- hardware / dataloading -------------------------------------------------
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help='device for the training/testing; to use CPU, set to empty string (""); to use multiple gpu, set it as a comma separated list, e.g., `1,2,3,4`',
)

parser.add_argument(
    "--num-workers",
    type=int,
    default=1,
    help="number of threads to load the dataset; memory consumption and disk access load increases (~linearly) with this numbers",
)
parser.add_argument(
    "--prefetch-factor",
    type=int,
    default=1,
    help="How many items to prefetch in the dataloaders. Should be about the same order of magnitude as batch size for optimal performance.",
)
parser.add_argument(
    "--predict",
    action="store_true",
    default=False,
    help="run prediction instead of training",
)


# ---- wandb logging ----------------------------------------------------------
parser.add_argument(
    "--log-wandb", action="store_true", default=False, help="use wandb for loging"
)
parser.add_argument(
    "--wandb-displayname",
    type=str,
    help="give display name to wandb run, if not entered a random one is generated",
)
parser.add_argument(
    "--wandb-projectname", type=str, help="project where the run is stored inside wandb"
)
parser.add_argument(
    "--wandb-entity", type=str, help="username or team name where you are sending runs"
)


# ---- object-condensation hyper-parameters -----------------------------------
parser.add_argument(
    "--qmin", type=float, default=0.1, help="define qmin for condensation"
)


parser.add_argument(
    "--frac_cluster_loss",
    type=float,
    default=0,
    help="Fraction of total pairs to use for the clustering loss",
)


parser.add_argument(
    "--use-average-cc-pos",
    default=0.0,
    type=float,
    help="push the alpha to the mean of the coordinates in the object by this value",
)


# ---- training-mode switches -------------------------------------------------
parser.add_argument(
    "--correction",
    action="store_true",
    default=False,
    help="Train correction only",
)


parser.add_argument(
    "--use-gt-clusters",
    default=False,
    action="store_true",
    help="If toggled, uses ground-truth clusters instead of the predicted ones by the model. We can use this to simulate 'ideal' clustering.",
)


# ---- evaluation output ------------------------------------------------------
parser.add_argument(
    "--name-output",
    type=str,
    help="name of the dataframe stored during eval",
)
parser.add_argument(
    "--train-batches",
    default=100,
    type=int,
    help="number of train batches",
)
parser.add_argument(
    "--pandora",
    default=False,
    action="store_true",
    help="using pandora information",
)
src/utils/pid_conversion.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# A global variable, so it doesn't have to be modified in 10 different places when new particles are added

# PDG code -> coarse class. Judging from our_to_pandora_mapping below:
# 0 = electron, 1 = charged hadron, 2 = neutral hadron, 3 = photon, 4 = muon.
# Float keys are nucleus codes (10-digit PDG IDs) — NOTE(review): they only
# match if the source column stores them as floats; confirm dtype upstream.
pid_conversion_dict = {11: 0, -11: 0, 211: 1, -211: 1, 130: 2, -130: 2, 2112: 2, -2112: 2, 22: 3, 321: 1, -321: 1, 2212: 1, -2212: 1, 310: 2, -310: 2, 3122: 2, -3122: 2, 3212: 2, -3212: 2, 3112: 1, -3112: 1, 3222: 1, -3222: 1, 3224: 1, -3224: 1, 3312: 2, -3312: 2, 13: 4, -13: 4, 3322: 2, -3322: 2, 1000020030.0: 2, 1000010050.0: 2, 1000010048.0: 2, 3334: 1, -3334:1, 1000020032.0: 2, 1000080128.0: 2, 1000110208.0: 2, 1000040064.0: 2, 1000070144.0: 2, 1000010020.0:2, 1000010030.0:2, 1000020040.0:2}

# Pandora PDG hypothesis -> coarse class (only codes Pandora actually emits).
pandora_to_our_mapping = {211: 1, -211: 1, -13: 4, 13: 4, 11: 0, -11: 0, 22: 3, 2112: 2, 130: 2, -2112: 2}
# Inverse direction: coarse class -> list of PDG codes it may represent.
our_to_pandora_mapping = {0: [11, -11], 1: [211, -211,2212, -2212, 321, -321, 3222, 3112, 3224, -3112, -3224], 2: [2112, 130, 310, 3122, 3212], 3: [22], 4:[13,-13]}
+
src/utils/post_clustering_features.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_scatter import scatter_sum, scatter_std
3
+
4
def calculate_phi(x, y, z=None):
    """Return the azimuthal angle of (x, y) in radians.

    The unused ``z`` argument keeps the signature interchangeable with
    ``calculate_eta``.
    """
    return torch.atan2(y, x)
6
+
7
def calculate_eta(x, y, z):
    """Pseudorapidity eta = -ln(tan(theta/2)), theta being the polar angle."""
    transverse = torch.sqrt(x * x + y * y)
    polar = torch.atan2(transverse, z)
    return -torch.log(torch.tan(polar / 2))
10
+
11
def get_post_clustering_features(graphs_new, sum_e):
    '''
    Obtain graph-level qualitative features that can then be used to regress the energy corr. factor.
    :param graph_batch: Output from the previous step - clustered, matched showers
    :param sum_e: per-graph summed hit energy (used for normalization)
    :return: (num_graphs, 11) feature tensor, NaNs replaced by 0
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    # Per-node graph index for scatter aggregation.
    batch_idx = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
    # Node feature layout assumed: [5] ECAL flag, [6] HCAL flag, [7] muon-chamber
    # flag, [8] hit energy, [9] track momentum — TODO confirm upstream.
    e_hits = graphs_new.ndata["h"][:, 8]

    muon_hits = graphs_new.ndata["h"][:, 7]
    filter_muon = torch.where(muon_hits)[0]
    per_graph_e_hits_muon = scatter_sum(e_hits[filter_muon], batch_idx[filter_muon], dim_size=batch_idx.max() + 1)
    per_graph_n_hits_muon = scatter_sum((e_hits[filter_muon] > 0).type(torch.int), batch_idx[filter_muon], dim_size=batch_idx.max() + 1)
    ecal_hits = graphs_new.ndata["h"][:, 5]
    filter_ecal = torch.where(ecal_hits)[0]
    hcal_hits = graphs_new.ndata["h"][:, 6]
    filter_hcal = torch.where(hcal_hits)[0]
    per_graph_e_hits_ecal = scatter_sum(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=batch_idx.max() + 1)
    # similar as above but with scatter_std (squared -> variance of hit energies)
    per_graph_e_hits_ecal_dispersion = scatter_std(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=batch_idx.max() + 1) ** 2
    per_graph_e_hits_hcal = scatter_sum(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=batch_idx.max() + 1)
    # similar as above but with scatter_std -- !!!!! TODO: Retrain the base EC models using this definition !!!!!
    per_graph_e_hits_hcal_dispersion = scatter_std(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=batch_idx.max() + 1) ** 2
    # track_nodes =
    # Track quantities: mean momentum and mean chi^2 over the graph's track hits;
    # graphs without tracks get momentum 0 (chi^2 becomes NaN, cleaned below).
    track_p = scatter_sum(graphs_new.ndata["h"][:, 9], batch_idx)
    chis_tracks = scatter_sum(graphs_new.ndata["chi_squared_tracks"], batch_idx)
    num_tracks = scatter_sum((graphs_new.ndata["h"][:, 9] > 0).type(torch.int), batch_idx)
    track_p = track_p / num_tracks
    track_p[num_tracks == 0] = 0.
    chis_tracks = chis_tracks / num_tracks
    num_hits = graphs_new.batch_num_nodes()
    # print shapes of the below things

    # NOTE(review): num_hits/num_tracks are integer tensors stacked with float
    # ones — relies on torch.stack/cat dtype promotion; confirm on the
    # deployed torch version.
    return torch.nan_to_num(
        torch.stack([per_graph_e_hits_ecal / sum_e,
                     per_graph_e_hits_hcal / sum_e,
                     num_hits, track_p,
                     per_graph_e_hits_ecal_dispersion,
                     per_graph_e_hits_hcal_dispersion,
                     sum_e, num_tracks, torch.clamp(chis_tracks, -5, 5),
                     per_graph_e_hits_muon,
                     per_graph_n_hits_muon
                     ]).T
    )
59
+
60
+
61
+
62
def get_extra_features(graphs_new, betas):
    '''
    Obtain extra graph-level features for debugging of the fakes.

    :param graphs_new: batched DGL graph of clustered showers
    :param betas: per-node condensation beta values, aligned with the
        batched node order
    :return: (num_graphs, 1 + n_highest_betas) tensor of
        [num_nodes, top-k betas (zero-padded)] per graph
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    batch_idx = []
    topk_highest_betas = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
    n_highest_betas = 1
    for i in range(len(batch_num_nodes)):
        betas_i = betas[batch_idx == i]
        # BUG FIX: torch.topk raises when k exceeds the number of elements,
        # which made the original zero-padding branch unreachable (and that
        # branch then called .values on an already-unwrapped tensor).
        # Clamp k to the available elements and pad explicitly instead.
        k = min(n_highest_betas, betas_i.numel())
        top_vals = torch.topk(betas_i, k).values if k > 0 else betas_i.new_zeros(0)
        if len(top_vals) < n_highest_betas:
            top_vals = torch.cat(
                [top_vals, top_vals.new_zeros(n_highest_betas - len(top_vals))]
            )
        topk_highest_betas.append(top_vals)
    topk_highest_betas = torch.stack(topk_highest_betas)
    # Concat with batch_num_nodes
    features = torch.cat([batch_num_nodes.view(-1, 1), topk_highest_betas], dim=1)
    return features
src/utils/train_utils.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ast
3
+ import sys
4
+ import shutil
5
+ import glob
6
+ import functools
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ from src.dataset.dataset import SimpleIterDataset
11
+ from src.utils.import_tools import import_module
12
+ from src.dataset.functions_graph import graph_batch_func
13
+
14
def set_gpus(args):
    """Parse ``args.gpus`` into a list of device ids and the primary device.

    Raises when no GPU list was supplied (the fallback to [0] is announced
    but deliberately aborted).
    """
    if not args.gpus:
        print("No GPUs flag provided - Setting GPUs to [0]")
        gpus = [0]
        dev = torch.device(gpus[0])
        raise Exception("Please provide GPU number")
    gpus = [int(g) for g in args.gpus.split(",")]
    dev = torch.device(gpus[0])
    print("Using GPUs:", gpus)
    return gpus, dev
25
+
26
+
27
+
28
def get_gpu_dev(args):
    """Map ``args.gpus`` onto a Lightning-style (accelerator, devices) pair.

    An empty GPU string yields (0, 0), i.e. no accelerator configured.
    """
    if args.gpus == "":
        return 0, 0
    return "gpu", args.gpus
36
+ # TODO change this to use it from config file
37
+
38
def model_setup(args, data_config):
    """
    Loads the model
    :param args:
    :param data_config:
    :return: the wrapped torch module (``model.mod``) of the network
    """
    network_module = import_module(args.network_config, name="_network_module")
    if args.gpus:
        gpu_ids = [int(g) for g in args.gpus.split(",")]
        dev = torch.device(gpu_ids[0])
        print("using GPUs:", gpu_ids)
    else:
        dev = torch.device("cpu")
    model, model_info = network_module.get_model(data_config, args=args, dev=dev)
    return model.mod
59
+
60
+
61
def get_samples_steps_per_epoch(args):
    """Derive the steps-per-epoch settings from samples-per-epoch flags.

    Mutates ``args`` in place and returns it. Supplying both a samples and a
    steps flag is an error; a missing validation step count is inferred from
    the train/val split; a negative value disables it (set to None).
    """
    if args.samples_per_epoch is not None:
        if args.steps_per_epoch is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch` or `--samples-per-epoch`, but not both!"
            )
        args.steps_per_epoch = args.samples_per_epoch // args.batch_size
    if args.samples_per_epoch_val is not None:
        if args.steps_per_epoch_val is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch-val` or `--samples-per-epoch-val`, but not both!"
            )
        args.steps_per_epoch_val = args.samples_per_epoch_val // args.batch_size
    if args.steps_per_epoch_val is None and args.steps_per_epoch is not None:
        args.steps_per_epoch_val = round(
            args.steps_per_epoch * (1 - args.train_val_split) / args.train_val_split
        )
    if args.steps_per_epoch_val is not None and args.steps_per_epoch_val < 0:
        args.steps_per_epoch_val = None
    return args
83
+
84
def to_filelist(args, mode="train"):
    """Expand the train/val file arguments into ({name: sorted files}, flat list).

    Entries may be keyword-based ('name:/glob/pattern'); unnamed entries go
    under key "_". With ``args.local_rank`` set, training files are sharded
    across ranks and shuffled.
    """
    if mode == "train":
        patterns = args.data_train
    elif mode == "val":
        patterns = args.data_val
    else:
        raise NotImplementedError("Invalid mode %s" % mode)

    # keyword-based: 'a:/path/to/a b:/path/to/b'
    file_dict = {}
    for entry in patterns:
        if ":" in entry:
            name, pattern = entry.split(":")
        else:
            name, pattern = "_", entry
        file_dict.setdefault(name, []).extend(glob.glob(pattern))

    # deterministic ordering before any sharding
    file_dict = {name: sorted(files) for name, files in file_dict.items()}

    if args.local_rank is not None:
        if mode == "train":
            gpus_list, _ = set_gpus(args)
            local_world_size = len(gpus_list)  # int(os.environ['LOCAL_WORLD_SIZE'])
            file_dict = {
                name: files[args.local_rank :: local_world_size]
                for name, files in file_dict.items()
            }
            for files in file_dict.values():
                assert len(files) > 0
                np.random.shuffle(files)
            print(args.local_rank, len(file_dict["_"]))

    filelist = sum(file_dict.values(), [])
    assert len(filelist) == len(set(filelist))
    return file_dict, filelist
126
+
127
+
128
def train_load(args):
    """
    Loads the training data.
    :param args:
    :return: train_loader, val_loader, data_config, train_inputs
    """
    train_file_dict, train_files = to_filelist(args, "train")
    if args.data_val:
        # Separate validation files: both datasets use their full range.
        val_file_dict, val_files = to_filelist(args, "val")
        train_range = val_range = (0, 1)
    else:
        # No validation files: carve the split out of the training files.
        val_file_dict, val_files = train_file_dict, train_files
        train_range = (0, args.train_val_split)
        val_range = (args.train_val_split, 1)

    rank_suffix = "" if args.local_rank is None else "_rank%d" % args.local_rank
    # Keyword arguments shared by the train and val dataset constructions.
    shared_ds_kwargs = dict(
        for_training=True,
        extra_selection=None,
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        args_parse=args,
    )
    train_data = SimpleIterDataset(
        train_file_dict,
        args.data_config,
        remake_weights=False,
        load_range_and_fraction=(train_range, args.data_fraction),
        infinity_mode=args.steps_per_epoch is not None,
        name="train" + rank_suffix,
        **shared_ds_kwargs,
    )
    val_data = SimpleIterDataset(
        val_file_dict,
        args.data_config,
        load_range_and_fraction=(val_range, args.data_fraction),
        infinity_mode=args.steps_per_epoch_val is not None,
        name="val" + rank_suffix,
        **shared_ds_kwargs,
    )

    # prefetch_factor must stay None when no workers are used.
    prefetch_factor = args.prefetch_factor if args.num_workers > 0 else None
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        num_workers=min(args.num_workers, int(len(train_files) * args.file_fraction)),
        collate_fn=graph_batch_func,
        persistent_workers=False,
        prefetch_factor=prefetch_factor,
    )
    val_loader = DataLoader(
        val_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        collate_fn=graph_batch_func,
        num_workers=min(args.num_workers, int(len(val_files) * args.file_fraction)),
        persistent_workers=args.num_workers > 0
        and args.steps_per_epoch_val is not None,
        prefetch_factor=prefetch_factor,
    )

    # Data-config metadata is currently disabled; placeholders kept so the
    # return signature matches callers' expectations.
    data_config = 0  # train_data.config
    train_input_names = 0  # train_data.config.input_names
    return train_loader, val_loader, data_config, train_input_names
210
+
211
+
212
def test_load(args):
    """
    Loads the test data.
    :param args:
    :return: test_loaders, data_config
    """
    # keyword-based --data-test: 'a:/path/to/a b:/path/to/b'
    # split --data-test: 'a%10:/path/to/a/*'
    file_dict = {}
    split_dict = {}
    for entry in args.data_test:
        if ":" in entry:
            name, pattern = entry.split(":")
            if "%" in name:
                name, split = name.split("%")
                split_dict[name] = int(split)
        else:
            name, pattern = "", entry
        file_dict.setdefault(name, []).extend(glob.glob(pattern))

    # deterministic ordering
    for name in file_dict:
        file_dict[name] = sorted(file_dict[name])

    # apply splitting: 'a%10' turns into chunks a_0, a_1, ... of 10 files each
    for name, split in split_dict.items():
        files = file_dict.pop(name)
        n_chunks = (len(files) + split - 1) // split
        for i in range(n_chunks):
            file_dict[f"{name}_{i}"] = files[i * split : (i + 1) * split]

    def get_test_loader(name):
        # Build the loader lazily so each chunk is only opened when requested.
        filelist = file_dict[name]
        test_data = SimpleIterDataset(
            {name: filelist},
            args.data_config,
            for_training=False,
            extra_selection=None,
            load_range_and_fraction=((0, 1), args.data_fraction),
            fetch_by_files=True,
            fetch_step=1,
            name="test_" + name,
            args_parse=args,
        )
        return DataLoader(
            test_data,
            num_workers=min(args.num_workers, len(filelist)),
            batch_size=args.batch_size,
            drop_last=False,
            pin_memory=True,
            collate_fn=graph_batch_func,
        )

    test_loaders = {
        name: functools.partial(get_test_loader, name) for name in file_dict
    }
    # data_config = SimpleIterDataset({}, args.data_config, for_training=False).config
    data_config = 0
    return test_loaders, data_config
276
+
277
+
278
def count_parameters(model):
    """Return the number of trainable parameters of the wrapped torch module."""
    trainable = (p.numel() for p in model.mod.parameters() if p.requires_grad)
    return sum(trainable)
280
+
281
+
tests/test_cpu_attention.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the CPU-compatible attention patch in src/inference.py."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def _cpu_sdpa_under_test(q, k, v, attn_mask=None):
8
+ """Standalone copy of _cpu_sdpa (the patched attention) for testing.
9
+
10
+ Mirrors the implementation in src.inference._patch_gatr_attention_for_cpu.
11
+ """
12
+ B, H, N, D = q.shape
13
+ scale = float(D) ** -0.5
14
+
15
+ q2 = q.reshape(B * H, N, D)
16
+ k2 = k.reshape(B * H, N, D)
17
+ v2 = v.reshape(B * H, N, D)
18
+
19
+ attn = torch.bmm(q2 * scale, k2.transpose(1, 2))
20
+
21
+ if attn_mask is not None:
22
+ attn = attn.masked_fill(~attn_mask.unsqueeze(0), float("-inf"))
23
+
24
+ attn = torch.softmax(attn, dim=-1)
25
+ attn = attn.nan_to_num(0.0)
26
+
27
+ out = torch.bmm(attn, v2)
28
+ return out.reshape(B, H, N, D)
29
+
30
+
31
def test_cpu_sdpa_matches_reference():
    """The CPU SDPA must agree with PyTorch's reference implementation."""
    torch.manual_seed(42)
    shape = (2, 4, 16, 32)  # (B, H, N, D)
    q, k, v = (torch.randn(*shape) for _ in range(3))

    out_ours = _cpu_sdpa_under_test(q, k, v)
    out_ref = F.scaled_dot_product_attention(q, k, v)  # reference, no mask

    assert out_ours.shape == shape
    assert torch.allclose(out_ours, out_ref, atol=1e-5), (
        f"Max diff: {(out_ours - out_ref).abs().max().item()}"
    )
47
+
48
+
49
def test_cpu_sdpa_output_shape():
    """Output shape must be [B, H, N, D], matching the input convention."""
    shape = (1, 8, 64, 16)
    q, k, v = (torch.randn(*shape) for _ in range(3))
    assert _cpu_sdpa_under_test(q, k, v).shape == shape
57
+
58
+
59
def test_cpu_sdpa_single_head():
    """Single-head attention must work correctly."""
    torch.manual_seed(0)
    shape = (1, 1, 10, 8)
    q, k, v = (torch.randn(*shape) for _ in range(3))
    assert torch.allclose(
        _cpu_sdpa_under_test(q, k, v),
        F.scaled_dot_product_attention(q, k, v),
        atol=1e-5,
    )
71
+
72
+
73
def test_cpu_sdpa_asymmetric_heads_items():
    """Ensure heads and items dimensions are not confused.

    When H != N, swapping them would change the tensor layout and
    produce different (wrong) results.
    """
    torch.manual_seed(123)
    shape = (1, 3, 7, 16)  # H != N
    q, k, v = (torch.randn(*shape) for _ in range(3))

    ours = _cpu_sdpa_under_test(q, k, v)
    ref = F.scaled_dot_product_attention(q, k, v)

    assert ours.shape == shape
    assert torch.allclose(ours, ref, atol=1e-5), (
        f"Max diff: {(ours - ref).abs().max().item()}"
    )
92
+
93
+
94
if __name__ == "__main__":
    # Allow running this file directly, without pytest.
    for case in (
        test_cpu_sdpa_matches_reference,
        test_cpu_sdpa_output_shape,
        test_cpu_sdpa_single_head,
        test_cpu_sdpa_asymmetric_heads_items,
    ):
        case()
    print("All tests passed.")
tests/test_csv_priority.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests that CSV data takes priority over parquet when both are available.
2
+
3
+ This validates the fix for the issue where loading an event from parquet and
4
+ then modifying the CSV text fields (e.g. removing tracks) was ignored because
5
+ the code always re-loaded from the parquet file.
6
+ """
7
+
8
+ import os
9
+ import ast
10
+ import textwrap
11
+
12
+
13
+ def _extract_source_priority_logic():
14
+ """Extract and verify the input-source priority logic from app.py.
15
+
16
+ Reads the ``run_inference_ui`` function source and checks that CSV
17
+ is tested *before* parquet, so that user edits to the CSV text
18
+ fields are respected even when a parquet file path is present.
19
+ """
20
+ app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
21
+ with open(app_path) as f:
22
+ source = f.read()
23
+ return source
24
+
25
+
26
+ def test_csv_checked_before_parquet():
27
+ """In run_inference_ui, the ``if use_csv`` branch must come before
28
+ ``use_parquet`` so that CSV edits are not silently ignored."""
29
+ source = _extract_source_priority_logic()
30
+
31
+ # Find positions of the key branching statements
32
+ idx_csv = source.find("if use_csv:")
33
+ idx_parquet_elif = source.find("elif use_parquet:")
34
+ idx_parquet_if = source.find("if use_parquet:")
35
+
36
+ # "if use_csv:" must exist
37
+ assert idx_csv != -1, "Could not find 'if use_csv:' in app.py"
38
+
39
+ # "elif use_parquet:" must exist (parquet is the fallback)
40
+ assert idx_parquet_elif != -1, (
41
+ "Could not find 'elif use_parquet:' in app.py — parquet should be "
42
+ "a fallback after CSV"
43
+ )
44
+
45
+ # CSV check must come before the parquet fallback
46
+ assert idx_csv < idx_parquet_elif, (
47
+ "'if use_csv:' must appear before 'elif use_parquet:' so that "
48
+ "user CSV edits take priority over re-reading the parquet file"
49
+ )
50
+
51
+ # There should NOT be a standalone "if use_parquet:" that would take
52
+ # priority over CSV (the old buggy pattern)
53
+ if idx_parquet_if != -1:
54
+ # The only occurrence should be inside the guard for empty input
55
+ # (not use_parquet and not use_csv). A standalone "if use_parquet:"
56
+ # that dispatches to load_event_from_parquet before checking CSV is
57
+ # the bug we fixed.
58
+ # Make sure it's not followed by load_event_from_parquet before
59
+ # "if use_csv:" appears
60
+ assert idx_parquet_if > idx_csv or "load_event_from_parquet" not in source[idx_parquet_if:idx_csv], (
61
+ "Found 'if use_parquet:' with load_event_from_parquet before "
62
+ "'if use_csv:' — this is the bug where parquet takes priority "
63
+ "over CSV edits"
64
+ )
65
+
66
+
67
+ def test_parse_csv_event_logic():
68
+ """_parse_csv_event should correctly build event dicts from CSV text.
69
+
70
+ We inline the same parsing logic used by app.py to avoid importing
71
+ the module (which requires heavy dependencies like gradio).
72
+ """
73
+ import io
74
+ import numpy as np
75
+ import pandas as pd
76
+
77
+ def _read(text, min_cols=1):
78
+ if not text or not text.strip():
79
+ return np.zeros((0, min_cols), dtype=np.float64)
80
+ df = pd.read_csv(io.StringIO(text), header=None)
81
+ return df.values.astype(np.float64)
82
+
83
+ def _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora=""):
84
+ hits_arr = _read(csv_hits, 11)
85
+ tracks_arr = _read(csv_tracks, 25)
86
+ particles_arr = _read(csv_particles, 18)
87
+ pandora_arr = _read(csv_pandora, 9)
88
+ if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
89
+ pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
90
+ tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
91
+ ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
92
+ ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
93
+ return {
94
+ "X_hit": hits_arr,
95
+ "X_track": tracks_arr,
96
+ "X_gen": particles_arr,
97
+ "X_pandora": pandora_arr,
98
+ "ygen_hit": ygen_hit,
99
+ "ygen_track": ygen_track,
100
+ }
101
+
102
+ # Basic parse
103
+ csv_hits = "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1"
104
+ event = _parse_csv_event(csv_hits, "", "", "")
105
+ assert event["X_hit"].shape == (1, 11)
106
+ assert event["X_track"].shape == (0, 25)
107
+ assert np.isclose(event["X_hit"][0, 5], 1.23)
108
+
109
+ # Empty tracks after removing them
110
+ event2 = _parse_csv_event(csv_hits, "", "", "")
111
+ assert event2["X_track"].shape[0] == 0
112
+
113
+ # Two tracks vs one track
114
+ csv_tracks_two = (
115
+ "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2\n"
116
+ "1,0,0,0,0,3.0,1.0,1.5,2.1,0,0,0,1700.0,100.0,80.0,10.0,6,0,0,0,0,0,0.9,1.4,2.0"
117
+ )
118
+ csv_tracks_one = (
119
+ "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
120
+ )
121
+ event_two = _parse_csv_event(csv_hits, csv_tracks_two, "", "")
122
+ event_one = _parse_csv_event(csv_hits, csv_tracks_one, "", "")
123
+ assert event_two["X_track"].shape[0] == 2
124
+ assert event_one["X_track"].shape[0] == 1
125
+
126
+
127
+ def test_input_source_decision_logic():
128
+ """Simulate the decision logic from run_inference_ui and verify that
129
+ CSV is used even when a parquet path is present."""
130
+
131
+ def decide_source(parquet_path, csv_hits):
132
+ """Mirrors the decision logic in run_inference_ui."""
133
+ use_parquet = parquet_path and os.path.isfile(parquet_path)
134
+ use_csv = bool(csv_hits and csv_hits.strip())
135
+
136
+ if use_csv:
137
+ return "csv"
138
+ elif use_parquet:
139
+ return "parquet"
140
+ else:
141
+ return "none"
142
+
143
+ # CSV present + parquet path present → should use CSV
144
+ # (use this script as a stand-in for an existing file)
145
+ existing_file = os.path.abspath(__file__)
146
+ assert decide_source(existing_file, "some,csv,data") == "csv"
147
+
148
+ # CSV present + no parquet → should use CSV
149
+ assert decide_source("", "some,csv,data") == "csv"
150
+
151
+ # CSV empty + parquet present → should use parquet
152
+ assert decide_source(existing_file, "") == "parquet"
153
+
154
+ # Both empty → none
155
+ assert decide_source("", "") == "none"
156
+
157
+
158
+ if __name__ == "__main__":
159
+ test_csv_checked_before_parquet()
160
+ test_parse_csv_event_logic()
161
+ test_input_source_decision_logic()
162
+ print("All tests passed.")
tests/test_energy_correction_no_matches.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests that energy correction runs even when no MC-truth showers are matched.
2
+
3
+ The bug: ``_run_energy_correction`` returned early (``if not graphs_matched:
4
+ return particles_df``) whenever no predicted cluster could be matched to a
5
+ true particle. In pure inference mode (no MC truth) *all* clusters are
6
+ "fakes" and ``graphs_matched`` is always empty, so the correction was never
7
+ applied and the output table only contained the basic ``energy_sum_hits`` /
8
+ ``p_track`` columns.
9
+
10
+ The fix: only bail out when *both* ``graphs_matched`` **and** ``graphs_fakes``
11
+ are empty (i.e. there are literally no clusters to correct).
12
+ """
13
+
14
+ import ast
15
+ import os
16
+
17
+
18
+ def _get_function_source(path, func_name):
19
+ """Return the source of a top-level function from *path*."""
20
+ with open(path) as f:
21
+ source = f.read()
22
+ tree = ast.parse(source)
23
+ lines = source.splitlines(keepends=True)
24
+ for node in tree.body:
25
+ if isinstance(node, ast.FunctionDef) and node.name == func_name:
26
+ return "".join(lines[node.lineno - 1 : node.end_lineno])
27
+ raise ValueError(f"{func_name} not found in {path}")
28
+
29
+
30
# Path to src/inference.py, resolved relative to this test file so the
# source-inspection tests below work from any working directory.
INFERENCE_PATH = os.path.join(
    os.path.dirname(__file__), "..", "src", "inference.py"
)
33
+
34
+
35
def test_early_return_requires_both_empty():
    """The early return must check both graphs_matched *and* graphs_fakes.

    The old (buggy) guard was:
        if not graphs_matched:
            return particles_df

    The fixed guard must be:
        if not graphs_matched and not graphs_fakes:
            return particles_df
    """
    src = _get_function_source(INFERENCE_PATH, "_run_energy_correction")

    buggy_guard = "if not graphs_matched:\n return particles_df"
    assert buggy_guard not in src, (
        "Found the old single-condition early return 'if not graphs_matched'; "
        "energy correction would be skipped whenever no MC-truth matches exist."
    )

    fixed_guard = "if not graphs_matched and not graphs_fakes:"
    assert fixed_guard in src, (
        "Expected 'if not graphs_matched and not graphs_fakes:' in "
        "_run_energy_correction but did not find it."
    )
58
+ )
59
+
60
+
61
def test_true_energies_t_not_called_with_cat_on_empty():
    """``torch.cat(true_energies, dim=0)`` must not appear unconditionally.

    When ``graphs_matched`` is empty, ``true_energies`` is an empty list and
    ``torch.cat([], dim=0)`` raises a RuntimeError. The fixed code removes
    this line entirely (the variable was unused anyway).
    """
    src = _get_function_source(INFERENCE_PATH, "_run_energy_correction")

    needle = "true_energies_t = torch.cat(true_energies"
    if needle not in src:
        return  # assignment removed entirely — nothing left to check

    # If the assignment still exists it must be guarded by an if-statement
    # within the preceding few lines.
    lines = src.splitlines()
    for i, line in enumerate(lines):
        if needle not in line:
            continue
        guard_present = any(
            "if true_energies" in lines[j] or "if graphs_matched" in lines[j]
            for j in range(max(0, i - 5), i)
        )
        assert guard_present, (
            f"Line {i}: unguarded 'torch.cat(true_energies)' would "
            "raise RuntimeError on empty list when no showers match."
        )
85
+
86
+
87
if __name__ == "__main__":
    # Allow running this file directly, without pytest.
    for case in (
        test_early_return_requires_both_empty,
        test_true_energies_t_not_called_with_cat_on_empty,
    ):
        case()
    print("All tests passed.")
tests/test_pfo_links.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the hit → Pandora cluster mapping (PFO links) field.
2
+
3
+ Validates that:
4
+ 1. _parse_csv_event correctly parses the csv_pfo_links parameter.
5
+ 2. PFO links are gracefully handled when CSV is modified (partial matches).
6
+ 3. The _load_event_into_csv function includes PFO links output.
7
+ 4. The run_inference_ui function accepts the csv_pfo_links parameter.
8
+ """
9
+
10
+ import os
11
+ import io
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Inline the parsing logic to avoid importing app.py (heavy dependencies)
18
+ # ---------------------------------------------------------------------------
19
+
20
+ def _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora="", csv_pfo_links=""):
21
+ """Mirror of the _parse_csv_event logic from app.py."""
22
+
23
+ def _read(text, min_cols=1):
24
+ if not text or not text.strip():
25
+ return np.zeros((0, min_cols), dtype=np.float64)
26
+ df = pd.read_csv(io.StringIO(text), header=None)
27
+ return df.values.astype(np.float64)
28
+
29
+ hits_arr = _read(csv_hits, 11)
30
+ tracks_arr = _read(csv_tracks, 25)
31
+ particles_arr = _read(csv_particles, 18)
32
+ pandora_arr = _read(csv_pandora, 9)
33
+ if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
34
+ pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
35
+ tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
36
+ ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
37
+ ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
38
+
39
+ # Parse PFO link arrays
40
+ pfo_calohit = np.array([], dtype=np.int64)
41
+ pfo_track = np.array([], dtype=np.int64)
42
+ if csv_pfo_links and csv_pfo_links.strip():
43
+ lines = csv_pfo_links.strip().split("\n")
44
+ if len(lines) >= 1 and lines[0].strip():
45
+ pfo_calohit = np.array(
46
+ [int(v) for v in lines[0].strip().split(",")], dtype=np.int64
47
+ )
48
+ if len(lines) >= 2 and lines[1].strip():
49
+ pfo_track = np.array(
50
+ [int(v) for v in lines[1].strip().split(",")], dtype=np.int64
51
+ )
52
+
53
+ return {
54
+ "X_hit": hits_arr,
55
+ "X_track": tracks_arr,
56
+ "X_gen": particles_arr,
57
+ "X_pandora": pandora_arr,
58
+ "ygen_hit": ygen_hit,
59
+ "ygen_track": ygen_track,
60
+ "pfo_calohit": pfo_calohit,
61
+ "pfo_track": pfo_track,
62
+ }
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Tests
67
+ # ---------------------------------------------------------------------------
68
+
69
def test_parse_pfo_links_basic():
    """PFO links should be correctly parsed from csv_pfo_links."""
    hits = (
        "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
        "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
    )
    event = _parse_csv_event(hits, "", "", "", "3,5\n7")

    assert "pfo_calohit" in event
    assert "pfo_track" in event
    np.testing.assert_array_equal(event["pfo_calohit"], [3, 5])
    np.testing.assert_array_equal(event["pfo_track"], [7])
80
+
81
+
82
def test_parse_pfo_links_empty():
    """Empty csv_pfo_links should produce empty arrays."""
    event = _parse_csv_event("0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1", "", "", "", "")
    assert event["pfo_calohit"].size == 0
    assert event["pfo_track"].size == 0
89
+
90
+
91
def test_parse_pfo_links_calohit_only():
    """Only calohit line provided (no track line)."""
    event = _parse_csv_event("0,0,0,0,0,1.0,1.0,1.0,1.0,0,1", "", "", "", "1,2,-1,3")
    np.testing.assert_array_equal(event["pfo_calohit"], [1, 2, -1, 3])
    assert len(event["pfo_track"]) == 0
98
+
99
+
100
def test_parse_pfo_links_with_negatives():
    """PFO links should correctly handle -1 values (unassigned hits)."""
    event = _parse_csv_event("", "", "", "", "3,-1,5,-1\n-1,2")
    np.testing.assert_array_equal(event["pfo_calohit"], [3, -1, 5, -1])
    np.testing.assert_array_equal(event["pfo_track"], [-1, 2])
107
+
108
+
109
def test_pandora_cluster_partial_match():
    """When CSV is modified (fewer hits than PFO links), use min of lengths."""
    # Simulate the assignment logic from inference.py: event now has fewer
    # hits/tracks than the original PFO link arrays.
    n_calo, n_tracks = 3, 1
    pfo_calohit = np.array([0, 1, 2, 3, 4], dtype=np.int64)  # originally 5 hits
    pfo_track = np.array([5, 6], dtype=np.int64)  # originally 2 tracks

    ids = np.full(n_calo + n_tracks, -1, dtype=np.int64)
    if pfo_calohit.size:
        n = min(len(pfo_calohit), n_calo)
        ids[:n] = pfo_calohit[:n]
    if n_tracks and pfo_track.size:
        n = min(len(pfo_track), n_tracks)
        ids[n_calo : n_calo + n] = pfo_track[:n]

    # First 3 calo hits keep their PFO IDs; the track slot gets the first
    # track PFO.
    np.testing.assert_array_equal(ids, [0, 1, 2, 5])
129
+
130
+
131
def test_pandora_cluster_no_links():
    """When no PFO links are available, all pandora_cluster_ids should be -1."""
    n_calo, n_tracks = 3, 2
    pfo_calohit = np.array([], dtype=np.int64)
    pfo_track = np.array([], dtype=np.int64)

    ids = np.full(n_calo + n_tracks, -1, dtype=np.int64)
    if pfo_calohit.size:
        n = min(len(pfo_calohit), n_calo)
        ids[:n] = pfo_calohit[:n]
    if n_tracks and pfo_track.size:
        n = min(len(pfo_track), n_tracks)
        ids[n_calo : n_calo + n] = pfo_track[:n]

    np.testing.assert_array_equal(ids, [-1] * 5)
149
+
150
+
151
def test_pandora_cluster_more_hits_than_links():
    """When more hits exist than PFO links, extra hits get -1."""
    n_calo, n_tracks = 5, 2
    pfo_calohit = np.array([1, 2], dtype=np.int64)  # only 2 links for 5 hits
    pfo_track = np.array([3], dtype=np.int64)  # only 1 link for 2 tracks

    ids = np.full(n_calo + n_tracks, -1, dtype=np.int64)
    if pfo_calohit.size:
        n = min(len(pfo_calohit), n_calo)
        ids[:n] = pfo_calohit[:n]
    if n_tracks and pfo_track.size:
        n = min(len(pfo_track), n_tracks)
        ids[n_calo : n_calo + n] = pfo_track[:n]

    np.testing.assert_array_equal(ids, [1, 2, -1, -1, -1, 3, -1])
169
+
170
+
171
def test_app_source_has_csv_pfo_links_field():
    """app.py should have the csv_pfo_links text field wired up."""
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as f:
        app_source = f.read()

    assert "csv_pfo_links" in app_source, "app.py should reference csv_pfo_links"
    assert "Hit → Pandora Cluster links" in app_source, (
        "app.py should have the PFO links text field label"
    )
180
+ )
181
+
182
+
183
def test_run_inference_ui_accepts_pfo_links():
    """run_inference_ui should accept csv_pfo_links as a parameter."""
    import ast
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as f:
        tree = ast.parse(f.read())

    matches = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name == "run_inference_ui"
    ]
    if not matches:
        raise AssertionError("Could not find run_inference_ui function in app.py")
    arg_names = [arg.arg for arg in matches[0].args.args]
    assert "csv_pfo_links" in arg_names, (
        "run_inference_ui should accept csv_pfo_links parameter"
    )
198
+
199
+
200
def test_load_event_returns_pfo_links():
    """_load_event_into_csv error path should return 6 values (including PFO links)."""
    import ast
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as f:
        tree = ast.parse(f.read())

    for node in ast.walk(tree):
        if not (isinstance(node, ast.FunctionDef) and node.name == "_load_event_into_csv"):
            continue
        # Check tuple-return statements in the function body
        for child in ast.walk(node):
            if isinstance(child, ast.Return) and isinstance(child.value, ast.Tuple):
                n_elts = len(child.value.elts)
                assert n_elts == 6, (
                    f"_load_event_into_csv should return 6 values, got {n_elts}"
                )
        return
    raise AssertionError("Could not find _load_event_into_csv function in app.py")
218
+
219
+
220
if __name__ == "__main__":
    # Allow running this file directly, without pytest.
    for case in (
        test_parse_pfo_links_basic,
        test_parse_pfo_links_empty,
        test_parse_pfo_links_calohit_only,
        test_parse_pfo_links_with_negatives,
        test_pandora_cluster_partial_match,
        test_pandora_cluster_no_links,
        test_pandora_cluster_more_hits_than_links,
        test_app_source_has_csv_pfo_links_field,
        test_run_inference_ui_accepts_pfo_links,
        test_load_event_returns_pfo_links,
    ):
        case()
    print("All PFO links tests passed.")