Spaces:
Sleeping
Sleeping
github-actions[bot] committed on
Commit ·
cc0720f
0
Parent(s):
Sync from GitHub f6dbbfb
Browse files- .github/workflows/sync-to-hf-space.yml +41 -0
- .gitignore +49 -0
- Dockerfile +15 -0
- README.md +79 -0
- app.py +772 -0
- config_files/config_hits_track_v4.yaml +146 -0
- scripts/evaluation.sh +26 -0
- scripts/train_clustering.sh +20 -0
- scripts/train_energy_pid.sh +24 -0
- src/data/config.py +218 -0
- src/data/fileio.py +101 -0
- src/data/preprocess.py +253 -0
- src/data/tools.py +191 -0
- src/dataset/dataclasses.py +126 -0
- src/dataset/dataset.py +287 -0
- src/dataset/functions_data.py +26 -0
- src/dataset/functions_graph.py +105 -0
- src/dataset/functions_particles.py +122 -0
- src/inference.py +735 -0
- src/layers/clustering.py +99 -0
- src/layers/inference_oc.py +251 -0
- src/layers/object_cond.py +609 -0
- src/layers/regression/loss_regression.py +59 -0
- src/layers/shower_dataframe.py +441 -0
- src/layers/shower_matching.py +127 -0
- src/layers/tools_for_regression.py +131 -0
- src/layers/utils_training.py +166 -0
- src/models/E_correction_module.py +43 -0
- src/models/Gatr_pf_e_noise.py +332 -0
- src/models/energy_correction_NN.py +299 -0
- src/models/energy_correction_charged.py +116 -0
- src/models/energy_correction_neutral.py +157 -0
- src/models/wrapper/example_mode_gatr_noise.py +21 -0
- src/train_lightning1.py +128 -0
- src/utils/callbacks.py +30 -0
- src/utils/import_tools.py +8 -0
- src/utils/inference/pandas_helpers.py +36 -0
- src/utils/load_pretrained_models.py +32 -0
- src/utils/logger_wandb.py +33 -0
- src/utils/parser_args.py +246 -0
- src/utils/pid_conversion.py +7 -0
- src/utils/post_clustering_features.py +82 -0
- src/utils/train_utils.py +281 -0
- tests/test_cpu_attention.py +99 -0
- tests/test_csv_priority.py +162 -0
- tests/test_energy_correction_no_matches.py +90 -0
- tests/test_pfo_links.py +231 -0
.github/workflows/sync-to-hf-space.yml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync to Hugging Face Space
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
|
| 8 |
+
permissions:
|
| 9 |
+
contents: read
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
sync-to-hf:
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- name: Checkout repo (no history)
|
| 16 |
+
uses: actions/checkout@v4
|
| 17 |
+
with:
|
| 18 |
+
fetch-depth: 1
|
| 19 |
+
lfs: false
|
| 20 |
+
|
| 21 |
+
- name: Push to Hugging Face Space
|
| 22 |
+
env:
|
| 23 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 24 |
+
run: |
|
| 25 |
+
# Configure git
|
| 26 |
+
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
| 27 |
+
git config --global user.name "github-actions[bot]"
|
| 28 |
+
|
| 29 |
+
# Use a credential helper to avoid embedding the token in the URL
|
| 30 |
+
git config --global credential.helper store
|
| 31 |
+
printf 'https://user:%s@huggingface.co\n' "$HF_TOKEN" > ~/.git-credentials
|
| 32 |
+
|
| 33 |
+
# Create a fresh repo with a single commit (no history)
|
| 34 |
+
cd $GITHUB_WORKSPACE
|
| 35 |
+
rm -rf .git
|
| 36 |
+
git init --initial-branch main
|
| 37 |
+
git add .
|
| 38 |
+
git commit -m "Sync from GitHub ${GITHUB_SHA::7}"
|
| 39 |
+
|
| 40 |
+
# Force-push the single commit to HF Space
|
| 41 |
+
git push --force https://huggingface.co/spaces/gregorkrzmanc/HitPF_demo main
|
.gitignore
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
*.egg-info/
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
.eggs/
|
| 10 |
+
|
| 11 |
+
# Jupyter
|
| 12 |
+
.ipynb_checkpoints/
|
| 13 |
+
*.ipynb
|
| 14 |
+
|
| 15 |
+
# Weights & Biases
|
| 16 |
+
wandb/
|
| 17 |
+
|
| 18 |
+
# Model checkpoints and outputs
|
| 19 |
+
*.pt
|
| 20 |
+
*.pth
|
| 21 |
+
showers_df_evaluation/
|
| 22 |
+
|
| 23 |
+
# Data files
|
| 24 |
+
*.root
|
| 25 |
+
*.h5
|
| 26 |
+
*.hdf5
|
| 27 |
+
*.pkl
|
| 28 |
+
*.pickle
|
| 29 |
+
*.npy
|
| 30 |
+
*.npz
|
| 31 |
+
|
| 32 |
+
# Demo files are downloaded at runtime from Hugging Face Hub
|
| 33 |
+
model_clustering.ckpt
|
| 34 |
+
model_e_pid.ckpt
|
| 35 |
+
test_data.parquet
|
| 36 |
+
|
| 37 |
+
# Logs
|
| 38 |
+
*.log
|
| 39 |
+
logs/
|
| 40 |
+
|
| 41 |
+
# Editors
|
| 42 |
+
.vscode/
|
| 43 |
+
.idea/
|
| 44 |
+
*.swp
|
| 45 |
+
*.swo
|
| 46 |
+
*~
|
| 47 |
+
|
| 48 |
+
# OS
|
| 49 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM dologarcia/gatr:v9
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN pip install --no-cache-dir \
|
| 6 |
+
densitypeakclustering \
|
| 7 |
+
lightning-utilities \
|
| 8 |
+
torchmetrics \
|
| 9 |
+
gradio \
|
| 10 |
+
plotly
|
| 11 |
+
|
| 12 |
+
COPY . .
|
| 13 |
+
EXPOSE 7860
|
| 14 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 15 |
+
CMD ["python", "app.py"]
|
README.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: HitPF
|
| 3 |
+
emoji: ⚛️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_file: app.py
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# HitPF
|
| 12 |
+
|
| 13 |
+
**HitPF** is a GATr-based particle-flow reconstruction model for the CLD detector at the FCC-ee.
|
| 14 |
+
It performs two sequential tasks:
|
| 15 |
+
|
| 16 |
+
1. **Clustering** — groups calorimeter hits and tracks into particle-flow objects using an object-condensation loss.
|
| 17 |
+
2. **Property regression** — regresses a correction factor for each reconstructed cluster using a GNN-based model and a PID class.
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Dependencies
|
| 22 |
+
|
| 23 |
+
The code can be used with this container:
|
| 24 |
+
```docker://dologarcia/gatr:v9```
|
| 25 |
+
|
| 26 |
+
For the live demo, gradio and plotly also need to be installed:
|
| 27 |
+
```
|
| 28 |
+
pip install gradio plotly
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Dataset
|
| 34 |
+
|
| 35 |
+
Input data is stored as `.parquet` files, each file stores 100 events. A sample of the dataset in ML-ready format is available [on Zenodo](https://zenodo.org/records/18749298). The full dataset is hosted on CERN's EOS space.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## Training
|
| 41 |
+
|
| 42 |
+
### Step 1 — Clustering
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
bash scripts/train_clustering.sh
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
### Step 2 — Energy correction
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
bash scripts/train_energy_pid.sh
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
### Validation
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
bash scripts/evaluation.sh
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
### Live demo (work in progress)
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
python -m app
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Citation
|
| 70 |
+
|
| 71 |
+
If you use this code, please cite:
|
| 72 |
+
|
| 73 |
+
```bibtex
|
| 74 |
+
@software{hitpf2026,
|
| 75 |
+
title = {End-to-end event reconstruction for precision physics at future colliders code},
|
| 76 |
+
year = {2026},
|
| 77 |
+
url = {https://github.com/mgarciam/HitPF}
|
| 78 |
+
}
|
| 79 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Gradio UI for single-event MLPF inference.
|
| 4 |
+
|
| 5 |
+
Launch with:
|
| 6 |
+
python app.py [--device cpu]
|
| 7 |
+
|
| 8 |
+
The UI lets you:
|
| 9 |
+
1. Load an event from a parquet file (pick file + event index), **or**
|
| 10 |
+
paste hit / track / particle data in CSV format.
|
| 11 |
+
2. (Optionally) load pre-trained model checkpoints.
|
| 12 |
+
3. Run inference → view predicted particles and the hit→cluster mapping.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import shutil
|
| 18 |
+
import traceback
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import numpy as np
|
| 23 |
+
import plotly.graph_objects as go
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Auto-download demo files from Hugging Face Hub if they are not present
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
_HF_REPO_ID = "gregorkrzmanc/hitpf_demo_files"
|
| 31 |
+
_DEMO_FILES = [
|
| 32 |
+
"model_clustering.ckpt",
|
| 33 |
+
"model_e_pid.ckpt",
|
| 34 |
+
"test_data.parquet",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _ensure_demo_files(dest_dir: str = ".") -> None:
    """Download demo files from Hugging Face Hub if they don't already exist.

    Missing files are fetched from the ``_HF_REPO_ID`` dataset repo and
    copied into *dest_dir*. Failures are reported but never raised, so the
    app can still start without the demo assets.
    """
    for fname in _DEMO_FILES:
        target = os.path.join(dest_dir, fname)
        if os.path.isfile(target):
            continue  # already present — nothing to do
        try:
            print(f"Downloading {fname} from HF Hub ({_HF_REPO_ID}) …")
            cached_path = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename=fname,
                repo_type="dataset",
            )
            shutil.copy(cached_path, target)
            print(f" → saved to {target}")
        except Exception as exc:
            # Best effort: the UI degrades gracefully without the demo files.
            print(f" ⚠️ Could not download {fname}: {exc}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
_ensure_demo_files()
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Global state – filled lazily
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
_MODEL = None
|
| 62 |
+
_ARGS = None
|
| 63 |
+
_DEVICE = "cpu"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _set_device(device: str):
    """Record the target inference device (e.g. "cpu") in module state."""
    global _DEVICE
    _DEVICE = device
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Model loading
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
def load_model_ui(clustering_ckpt: str, energy_pid_ckpt: str, device: str):
    """Load model from checkpoint paths (called by the UI button).

    Populates the module-level ``_MODEL``/``_ARGS``/``_DEVICE`` state and
    returns a markdown status string. Failures never raise — the traceback
    is embedded in the returned message instead.
    """
    global _MODEL, _ARGS, _DEVICE
    _DEVICE = device or "cpu"

    if not clustering_ckpt or not os.path.isfile(clustering_ckpt):
        return "⚠️ Please provide a valid path to the clustering checkpoint."

    # The energy/PID stage is optional — fall back to clustering-only mode
    # when no valid checkpoint path is supplied.
    has_energy_pid = bool(energy_pid_ckpt) and os.path.isfile(energy_pid_ckpt)
    energy_pid = energy_pid_ckpt if has_energy_pid else None

    try:
        from src.inference import load_model

        _MODEL, _ARGS = load_model(
            clustering_ckpt=clustering_ckpt,
            energy_pid_ckpt=energy_pid,
            device=_DEVICE,
        )
    except Exception:
        return f"❌ Failed to load model:\n```\n{traceback.format_exc()}\n```"

    suffix = (
        " (clustering + energy/PID correction)"
        if energy_pid
        else " (clustering only — no energy/PID correction)"
    )
    return f"✅ Model loaded on **{_DEVICE}**" + suffix
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
# Event loading helpers
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
def _count_events_in_parquet(parquet_path: str) -> str:
|
| 107 |
+
"""Return a short info string about the parquet file."""
|
| 108 |
+
if not parquet_path or not os.path.isfile(parquet_path):
|
| 109 |
+
return "No file selected"
|
| 110 |
+
try:
|
| 111 |
+
from src.inference import load_event_from_parquet
|
| 112 |
+
from src.data.fileio import _read_parquet
|
| 113 |
+
table = _read_parquet(parquet_path)
|
| 114 |
+
n = len(table["X_track"])
|
| 115 |
+
return f"File has **{n}** events (indices 0–{n-1})"
|
| 116 |
+
except Exception as e:
|
| 117 |
+
return f"Error reading file: {e}"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _load_event_into_csv(parquet_path: str, event_index: int):
    """Load an event from a parquet file and return CSV strings for the text fields.

    Returns a 6-tuple of strings: hits CSV, tracks CSV, MC-particle CSV,
    Pandora CSV, PFO-link CSV (calo-hit row followed by track row), and a
    status message for the UI. On any failure the first five entries are
    empty and the last carries the error message.
    """
    if not parquet_path or not os.path.isfile(parquet_path):
        return "", "", "", "", "", "⚠️ Please provide a valid parquet file path."
    try:
        from src.inference import load_event_from_parquet

        event = load_event_from_parquet(parquet_path, int(event_index))

        hits = np.asarray(event.get("X_hit", []))
        tracks = np.asarray(event.get("X_track", []))
        particles = np.asarray(event.get("X_gen", []))
        pandora = np.asarray(event.get("X_pandora", []))

        def _matrix_csv(arr):
            # Only 2-D arrays render as row-per-line CSV; anything else → "".
            if arr.ndim != 2:
                return ""
            return "\n".join(",".join(str(v) for v in row) for row in arr)

        def _int_row_csv(arr):
            return "" if len(arr) == 0 else ",".join(str(int(v)) for v in arr)

        calohit_csv = _int_row_csv(np.asarray(event.get("pfo_calohit", []), dtype=np.int64))
        track_csv = _int_row_csv(np.asarray(event.get("pfo_track", []), dtype=np.int64))
        # Line 1 carries calo-hit links, line 2 track links. Keep the leading
        # newline when only track links exist so row positions stay fixed.
        if calohit_csv and track_csv:
            links = calohit_csv + "\n" + track_csv
        elif calohit_csv:
            links = calohit_csv
        elif track_csv:
            links = "\n" + track_csv
        else:
            links = ""

        def _row_count(arr):
            return arr.shape[0] if arr.ndim == 2 else 0

        status = (
            f"✅ Loaded event **{int(event_index)}**: "
            f"{_row_count(hits)} hits, "
            f"{_row_count(tracks)} tracks, "
            f"{_row_count(particles)} MC particles, "
            f"{_row_count(pandora)} Pandora PFOs"
        )
        return (
            _matrix_csv(hits),
            _matrix_csv(tracks),
            _matrix_csv(particles),
            _matrix_csv(pandora),
            links,
            status,
        )
    except Exception as e:
        return "", "", "", "", "", f"❌ Error loading event: {e}"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _build_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by cluster ID.

    One trace per cluster (cluster 0 labelled "noise"); marker size scales
    linearly with hit energy. Rows with NaN/Inf coordinates are dropped.
    """
    if hit_cluster_df.empty:
        placeholder = go.Figure()
        placeholder.update_layout(title="No hit data available", height=600)
        return placeholder

    data = hit_cluster_df.copy()

    # Coerce coordinates/energy to numeric and drop NaN/Inf rows.
    for column in ("x", "y", "z", "hit_energy"):
        data[column] = pd.to_numeric(data[column], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=["x", "y", "z", "hit_energy"])
    if data.empty:
        placeholder = go.Figure()
        placeholder.update_layout(title="No valid hit data (all NaN/Inf)", height=600)
        return placeholder

    # Map hit energies onto marker sizes in [3, 15]; midpoint when all equal.
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12

    # Per-hit hover text (avoids mixed-type customdata serialization issues).
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + data["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + data["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + data["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + data["z"].map(lambda v: f"{v:.2f}")
    )

    ids = data["cluster_id"].values
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        selection = ids == cid
        chunk = data[selection]
        fig.add_trace(go.Scatter3d(
            x=chunk["x"].tolist(),
            y=chunk["y"].tolist(),
            z=chunk["z"].tolist(),
            mode="markers",
            name="noise" if cid == 0 else f"cluster {cid}",
            marker=dict(size=sizes_all[selection].tolist(), opacity=0.8),
            hovertext=chunk["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Hit → Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _build_pandora_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by Pandora cluster ID.

    One trace per Pandora PFO; cluster id -1 is labelled "unassigned".
    Marker size scales linearly with hit energy. Rows whose coordinates or
    energy are NaN/Inf are dropped.
    """
    if hit_cluster_df.empty or "pandora_cluster_id" not in hit_cluster_df.columns:
        placeholder = go.Figure()
        placeholder.update_layout(title="No Pandora cluster data available", height=600)
        return placeholder

    data = hit_cluster_df.copy()

    # Coerce coordinates/energy to numeric and drop NaN/Inf rows.
    for column in ("x", "y", "z", "hit_energy"):
        data[column] = pd.to_numeric(data[column], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=["x", "y", "z", "hit_energy"])
    if data.empty:
        placeholder = go.Figure()
        placeholder.update_layout(title="No valid hit data for Pandora plot (all NaN/Inf)", height=600)
        return placeholder

    # Map hit energies onto marker sizes in [3, 15]; midpoint when all equal.
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12

    # Per-hit hover text.
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Pandora cluster: " + data["pandora_cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "x: " + data["x"].map(lambda v: f"{v:.2f}")
        + ", y: " + data["y"].map(lambda v: f"{v:.2f}")
        + ", z: " + data["z"].map(lambda v: f"{v:.2f}")
    )

    ids = data["pandora_cluster_id"].values
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        selection = ids == cid
        chunk = data[selection]
        fig.add_trace(go.Scatter3d(
            x=chunk["x"].tolist(),
            y=chunk["y"].tolist(),
            z=chunk["z"].tolist(),
            mode="markers",
            name="unassigned" if cid == -1 else f"PFO {cid}",
            marker=dict(size=sizes_all[selection].tolist(), opacity=0.8),
            hovertext=chunk["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Hit → Pandora Cluster 3D Map",
        scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
        legend_title="Pandora PFO",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _build_clustering_space_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits in the learned clustering space.

    Uses the regressed ``cluster_x/y/z`` coordinates instead of detector
    positions; one trace per cluster (cluster 0 labelled "noise"). Marker
    size scales linearly with hit energy. Rows with NaN/Inf values are
    dropped.
    """
    if hit_cluster_df.empty or "cluster_x" not in hit_cluster_df.columns:
        placeholder = go.Figure()
        placeholder.update_layout(title="No clustering-space data available", height=600)
        return placeholder

    data = hit_cluster_df.copy()

    # Coerce coordinates/energy to numeric and drop NaN/Inf rows.
    for column in ("cluster_x", "cluster_y", "cluster_z", "hit_energy"):
        data[column] = pd.to_numeric(data[column], errors="coerce")
    data = data.replace([np.inf, -np.inf], np.nan).dropna(
        subset=["cluster_x", "cluster_y", "cluster_z", "hit_energy"]
    )
    if data.empty:
        placeholder = go.Figure()
        placeholder.update_layout(title="No valid clustering-space data (all NaN/Inf)", height=600)
        return placeholder

    # Map hit energies onto marker sizes in [3, 15]; midpoint when all equal.
    energy = data["hit_energy"].values.astype(float)
    lo, hi = float(energy.min()), float(energy.max())
    scaled = (energy - lo) / (hi - lo) if hi > lo else np.full_like(energy, 0.5)
    sizes_all = 3 + scaled * 12

    # Per-hit hover text.
    data["_hover"] = (
        "<b>" + data["hit_type"].astype(str) + "</b> hit #" + data["hit_index"].astype(int).astype(str) + "<br>"
        + "Cluster: " + data["cluster_id"].astype(int).astype(str) + "<br>"
        + "Energy: " + data["hit_energy"].map(lambda v: f"{v:.4f}") + "<br>"
        + "cluster_x: " + data["cluster_x"].map(lambda v: f"{v:.4f}")
        + ", cluster_y: " + data["cluster_y"].map(lambda v: f"{v:.4f}")
        + ", cluster_z: " + data["cluster_z"].map(lambda v: f"{v:.4f}")
    )

    ids = data["cluster_id"].values
    fig = go.Figure()
    for cid in sorted({int(c) for c in ids}):
        selection = ids == cid
        chunk = data[selection]
        fig.add_trace(go.Scatter3d(
            x=chunk["cluster_x"].tolist(),
            y=chunk["cluster_y"].tolist(),
            z=chunk["cluster_z"].tolist(),
            mode="markers",
            name="noise" if cid == 0 else f"cluster {cid}",
            marker=dict(size=sizes_all[selection].tolist(), opacity=0.8),
            hovertext=chunk["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title="Clustering Space 3D Map (GATr regressed coordinates)",
        scene=dict(
            xaxis_title="cluster_x",
            yaxis_title="cluster_y",
            zaxis_title="cluster_z",
        ),
        legend_title="Cluster",
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ---------------------------------------------------------------------------
|
| 380 |
+
# Main inference entry point for the UI
|
| 381 |
+
# ---------------------------------------------------------------------------
|
| 382 |
+
|
| 383 |
+
def _compute_inv_mass(df, e_col, px_col, py_col, pz_col):
|
| 384 |
+
"""Compute the invariant mass of a system of particles in GeV.
|
| 385 |
+
|
| 386 |
+
Returns the scalar invariant mass m = sqrt(max((ΣE)²−(Σpx)²−(Σpy)²−(Σpz)², 0)),
|
| 387 |
+
or *None* when *df* is empty or the required columns are absent.
|
| 388 |
+
"""
|
| 389 |
+
if df.empty:
|
| 390 |
+
return None
|
| 391 |
+
for col in (e_col, px_col, py_col, pz_col):
|
| 392 |
+
if col not in df.columns:
|
| 393 |
+
return None
|
| 394 |
+
E = float(df[e_col].sum())
|
| 395 |
+
px = float(df[px_col].sum())
|
| 396 |
+
py = float(df[py_col].sum())
|
| 397 |
+
pz = float(df[pz_col].sum())
|
| 398 |
+
m2 = E ** 2 - px ** 2 - py ** 2 - pz ** 2
|
| 399 |
+
return float(np.sqrt(max(m2, 0.0)))
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _fmt_mass(val):
|
| 403 |
+
"""Format an invariant-mass value (float or None) as a GeV string."""
|
| 404 |
+
return f"{val:.4f} GeV" if val is not None else "N/A"
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def run_inference_ui(
    parquet_path: str,
    event_index: int,
    csv_hits: str,
    csv_tracks: str,
    csv_particles: str,
    csv_pandora: str,
    csv_pfo_links: str = "",
):
    """Run inference on a single event and assemble all Gradio outputs.

    Parameters
    ----------
    parquet_path : str
        Path to a parquet file with events (Option A input).
    event_index : int
        Index of the event to load from the parquet file.
    csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links : str
        Pasted CSV text describing the event (Option B input). When hit CSV
        text is provided it takes priority over the parquet file.

    Returns
    -------
    particles_df : pandas.DataFrame
    cluster_fig : plotly.graph_objects.Figure
    clustering_space_fig : plotly.graph_objects.Figure
    pandora_cluster_fig : plotly.graph_objects.Figure
    mc_particles_df : pandas.DataFrame
    pandora_particles_df : pandas.DataFrame
    inv_mass_summary : str
    """
    global _MODEL, _ARGS, _DEVICE

    empty_fig = go.Figure()

    def _failure(message_df):
        # Build the 7-tuple returned on every early exit / error path,
        # so the output shape always matches the Gradio wiring.
        return (
            message_df,
            empty_fig,
            empty_fig,
            empty_fig,
            pd.DataFrame(),
            pd.DataFrame(),
            "",
        )

    if _MODEL is None:
        return _failure(
            pd.DataFrame({"error": ["Model not loaded. Please load a model first."]})
        )

    try:
        from src.inference import load_event_from_parquet, run_single_event_inference

        # Decide input source: pasted CSV takes priority over the parquet path.
        use_parquet = parquet_path and os.path.isfile(parquet_path)
        use_csv = bool(csv_hits and csv_hits.strip())

        if not use_parquet and not use_csv:
            return _failure(
                pd.DataFrame({"error": ["Provide a parquet file or paste CSV hit data."]})
            )

        if use_csv:
            event = _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links)
        else:
            event = load_event_from_parquet(parquet_path, int(event_index))

        particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df = run_single_event_inference(
            event, _MODEL, _ARGS, device=_DEVICE,
        )
        if particles_df.empty:
            particles_df = pd.DataFrame({"info": ["Event produced no clusters (empty graph)."]})

        cluster_fig = _build_cluster_plot(hit_cluster_df)
        clustering_space_fig = _build_clustering_space_plot(hit_cluster_df)
        pandora_cluster_fig = _build_pandora_cluster_plot(hit_cluster_df)

        # Compute invariant masses [GeV] of the summed per-algorithm 4-vectors.
        m_true = _compute_inv_mass(mc_particles_df, "energy", "px", "py", "pz")
        # HitPF uses corrected_energy when available, otherwise energy_sum_hits.
        hitpf_e_col = "corrected_energy" if "corrected_energy" in particles_df.columns else "energy_sum_hits"
        m_reco_hitpf = _compute_inv_mass(particles_df, hitpf_e_col, "px", "py", "pz")
        m_reco_pandora = _compute_inv_mass(pandora_particles_df, "energy", "px", "py", "pz")

        inv_mass_summary = (
            f"**Invariant mass (sum of all particle 4-vectors)**\n\n"
            f"| Algorithm | m [GeV] |\n"
            f"|---|---|\n"
            f"| m_true (MC truth) | {_fmt_mass(m_true)} |\n"
            f"| m_reco (HitPF) | {_fmt_mass(m_reco_hitpf)} |\n"
            f"| m_reco (Pandora) | {_fmt_mass(m_reco_pandora)} |"
        )

        return particles_df, cluster_fig, clustering_space_fig, pandora_cluster_fig, mc_particles_df, pandora_particles_df, inv_mass_summary

    except Exception:
        # Surface the full traceback in the UI rather than crashing Gradio.
        return _failure(pd.DataFrame({"error": [traceback.format_exc()]}))
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def _parse_csv_event(csv_hits: str, csv_tracks: str, csv_particles: str, csv_pandora: str = "", csv_pfo_links: str = ""):
|
| 508 |
+
"""Parse user-provided CSV text into the dict-of-arrays format expected by
|
| 509 |
+
``create_graph``.
|
| 510 |
+
|
| 511 |
+
Expected CSV columns for hits (X_hit) — 11 columns:
|
| 512 |
+
0: hit_x — hit position x [mm]
|
| 513 |
+
1: hit_y — hit position y [mm]
|
| 514 |
+
2: hit_z — hit position z [mm]
|
| 515 |
+
3: hit_px — hit momentum px [GeV] (0 for calo hits)
|
| 516 |
+
4: hit_py — hit momentum py [GeV] (0 for calo hits)
|
| 517 |
+
5: hit_energy — hit energy deposit [GeV]
|
| 518 |
+
6: hit_x_calo — hit position x at calorimeter surface [mm] (used as 3D position by the model)
|
| 519 |
+
7: hit_y_calo — hit position y at calorimeter surface [mm]
|
| 520 |
+
8: hit_z_calo — hit position z at calorimeter surface [mm]
|
| 521 |
+
9: (unused) — reserved column (set to 0)
|
| 522 |
+
10: hit_type — hit sub-detector type: 1 = ECAL, 2 = HCAL, 3 = muon system
|
| 523 |
+
|
| 524 |
+
Expected CSV columns for tracks (X_track) — 25 columns (padded with
|
| 525 |
+
zeros if fewer are provided; minimum 17):
|
| 526 |
+
0: elemtype — element type (always 1 for tracks)
|
| 527 |
+
1–4: (unused) — reserved columns (set to 0)
|
| 528 |
+
5: p — track momentum magnitude |p| [GeV]
|
| 529 |
+
6: px_IP — track px at interaction point [GeV]
|
| 530 |
+
7: py_IP — track py at interaction point [GeV]
|
| 531 |
+
8: pz_IP — track pz at interaction point [GeV]
|
| 532 |
+
9–11: (unused) — reserved columns (set to 0)
|
| 533 |
+
12: ref_x_calo — track reference-point x at calorimeter [mm]
|
| 534 |
+
13: ref_y_calo — track reference-point y at calorimeter [mm]
|
| 535 |
+
14: ref_z_calo — track reference-point z at calorimeter [mm]
|
| 536 |
+
15: chi2 — track-fit chi-squared
|
| 537 |
+
16: ndf — track-fit number of degrees of freedom
|
| 538 |
+
17–21: (unused) — reserved columns (set to 0)
|
| 539 |
+
22: px_calo — track momentum x component at calorimeter [GeV]
|
| 540 |
+
23: py_calo — track momentum y component at calorimeter [GeV]
|
| 541 |
+
24: pz_calo — track momentum z component at calorimeter [GeV]
|
| 542 |
+
|
| 543 |
+
Expected CSV columns for particles / MC truth (X_gen) — 18 columns:
|
| 544 |
+
0: pid — PDG particle ID (e.g. 211, 22, 11, 13)
|
| 545 |
+
1: gen_status — generator status code
|
| 546 |
+
2: isDecayedInCalo — 1 if decayed in calorimeter, else 0
|
| 547 |
+
3: isDecayedInTracker — 1 if decayed in tracker, else 0
|
| 548 |
+
4: theta — polar angle [rad]
|
| 549 |
+
5: phi — azimuthal angle [rad]
|
| 550 |
+
6: (unused) — reserved (set to 0)
|
| 551 |
+
7: (unused) — reserved (set to 0)
|
| 552 |
+
8: energy — true particle energy [GeV]
|
| 553 |
+
9: (unused) — reserved (set to 0)
|
| 554 |
+
10: mass — particle mass [GeV]
|
| 555 |
+
11: momentum — momentum magnitude |p| [GeV]
|
| 556 |
+
12: px — momentum x component [GeV]
|
| 557 |
+
13: py — momentum y component [GeV]
|
| 558 |
+
14: pz — momentum z component [GeV]
|
| 559 |
+
15: vx — production vertex x [mm]
|
| 560 |
+
16: vy — production vertex y [mm]
|
| 561 |
+
17: vz — production vertex z [mm]
|
| 562 |
+
|
| 563 |
+
PFO links (csv_pfo_links) — two lines of comma-separated integers:
|
| 564 |
+
Line 1: pfo_calohit — one PFO index per calorimeter hit (-1 = unassigned)
|
| 565 |
+
Line 2: pfo_track — one PFO index per track (-1 = unassigned)
|
| 566 |
+
"""
|
| 567 |
+
import io
|
| 568 |
+
import awkward as ak
|
| 569 |
+
|
| 570 |
+
def _read(text, min_cols=1):
|
| 571 |
+
if not text or not text.strip():
|
| 572 |
+
return np.zeros((0, min_cols), dtype=np.float64)
|
| 573 |
+
df = pd.read_csv(io.StringIO(text), header=None)
|
| 574 |
+
return df.values.astype(np.float64)
|
| 575 |
+
|
| 576 |
+
hits_arr = _read(csv_hits, 11)
|
| 577 |
+
tracks_arr = _read(csv_tracks, 25)
|
| 578 |
+
particles_arr = _read(csv_particles, 18)
|
| 579 |
+
pandora_arr = _read(csv_pandora, 9)
|
| 580 |
+
|
| 581 |
+
# Pad tracks to 25 columns if needed
|
| 582 |
+
if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
|
| 583 |
+
pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
|
| 584 |
+
tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
|
| 585 |
+
|
| 586 |
+
# Build ygen_hit / ygen_track (particle link per hit — use -1 for unknown)
|
| 587 |
+
ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
|
| 588 |
+
ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
|
| 589 |
+
|
| 590 |
+
# Parse PFO link arrays (hit → Pandora cluster mapping)
|
| 591 |
+
pfo_calohit = np.array([], dtype=np.int64)
|
| 592 |
+
pfo_track = np.array([], dtype=np.int64)
|
| 593 |
+
if csv_pfo_links and csv_pfo_links.strip():
|
| 594 |
+
lines = csv_pfo_links.strip().split("\n")
|
| 595 |
+
if len(lines) >= 1 and lines[0].strip():
|
| 596 |
+
pfo_calohit = np.array(
|
| 597 |
+
[int(v) for v in lines[0].strip().split(",")], dtype=np.int64
|
| 598 |
+
)
|
| 599 |
+
if len(lines) >= 2 and lines[1].strip():
|
| 600 |
+
pfo_track = np.array(
|
| 601 |
+
[int(v) for v in lines[1].strip().split(",")], dtype=np.int64
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
event = {
|
| 605 |
+
"X_hit": hits_arr,
|
| 606 |
+
"X_track": tracks_arr,
|
| 607 |
+
"X_gen": particles_arr,
|
| 608 |
+
"X_pandora": pandora_arr,
|
| 609 |
+
"ygen_hit": ygen_hit,
|
| 610 |
+
"ygen_track": ygen_track,
|
| 611 |
+
"pfo_calohit": pfo_calohit,
|
| 612 |
+
"pfo_track": pfo_track,
|
| 613 |
+
}
|
| 614 |
+
return event
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
# ---------------------------------------------------------------------------
|
| 618 |
+
# Build the Gradio interface
|
| 619 |
+
# ---------------------------------------------------------------------------
|
| 620 |
+
|
| 621 |
+
def build_app():
    """Assemble and return the Gradio Blocks UI for single-event inference.

    The interface has three accordion sections:
      1. Load Model   — checkpoint paths + device dropdown, wired to
         ``load_model_ui``.
      2. Select Event — parquet path/index (Option A) or pasted CSV text
         (Option B), wired to ``_count_events_in_parquet`` and
         ``_load_event_into_csv``.
      3. Results      — run button, invariant-mass summary, result tables
         and 3D plots, wired to ``run_inference_ui``.

    Returns
    -------
    gr.Blocks
        The fully wired (but not yet launched) Gradio application.
    """
    with gr.Blocks(title="HitPF — Single-event MLPF Inference") as demo:
        gr.Markdown(
            "# HitPF — Single-event MLPF Inference\n"
            "Run the GATr-based particle-flow reconstruction on a single event.\n\n"
            "**Steps:** 1) Load model checkpoints 2) Select an event 3) Run inference"
        )

        # ---- Model loading ----
        with gr.Accordion("1 · Load Model", open=True):
            with gr.Row():
                clustering_ckpt = gr.Textbox(
                    label="Clustering checkpoint (.ckpt)",
                    value="model_clustering.ckpt",
                    placeholder="/path/to/clustering.ckpt",
                )
                energy_pid_ckpt = gr.Textbox(
                    label="Energy / PID checkpoint (.ckpt) — optional",
                    value="model_e_pid.ckpt",
                    placeholder="/path/to/energy_pid.ckpt",
                )
                device_dd = gr.Dropdown(
                    choices=["cpu", "cuda:0", "cuda:1"],
                    value="cpu",
                    label="Device",
                )
            load_btn = gr.Button("Load model")
            load_status = gr.Markdown("")
            load_btn.click(
                fn=load_model_ui,
                inputs=[clustering_ckpt, energy_pid_ckpt, device_dd],
                outputs=load_status,
            )

        # ---- Event selection ----
        with gr.Accordion("2 · Select Event", open=True):
            gr.Markdown("**Option A** — from a parquet file:")
            with gr.Row():
                parquet_path = gr.Textbox(
                    label="Parquet file path",
                    value="test_data.parquet",
                    placeholder="/path/to/events.parquet",
                )
                event_idx = gr.Number(label="Event index", value=0, precision=0)
            parquet_info = gr.Markdown("")
            # Refresh the event count whenever the parquet path changes.
            parquet_path.change(
                fn=_count_events_in_parquet,
                inputs=parquet_path,
                outputs=parquet_info,
            )
            load_event_btn = gr.Button("Load event from parquet")
            load_event_status = gr.Markdown("")

            gr.Markdown(
                "---\n**Option B** — paste CSV data (one row per hit/track/particle, "
                "no header, comma-separated):\n"
            )

            csv_hits = gr.Textbox(
                label="Hits CSV (11 columns)",
                lines=4,
                placeholder=(
                    "Example (one ECAL hit, one HCAL hit):\n"
                    "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
                    "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
                ),
            )

            csv_tracks = gr.Textbox(
                label="Tracks CSV (25 columns; leave empty if none)",
                lines=3,
                placeholder=(
                    "Example (one track with p≈5 GeV):\n"
                    "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
                ),
            )

            csv_particles = gr.Textbox(
                label="Particles (MC truth) CSV (18 columns; optional)",
                lines=3,
                placeholder=(
                    "Example (one pion, one photon):\n"
                    "211,1,0,0,1.2,0.5,0,0,5.2,0,0.1396,5.198,3.1,2.0,3.3,0,0,0\n"
                    "22,1,0,0,0.8,2.1,0,0,1.5,0,0,1.5,0.5,-0.3,1.38,0,0,0"
                ),
            )

            csv_pandora = gr.Textbox(
                label="Pandora PFOs CSV (9 columns; optional)",
                lines=3,
                placeholder=(
                    "Columns: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum\n"
                    "Example (one charged pion PFO):\n"
                    "211,3.0,2.0,3.3,1800.0,150.0,90.0,5.2,5.198"
                ),
            )

            csv_pfo_links = gr.Textbox(
                label="Hit → Pandora Cluster links (optional; loaded from parquet)",
                lines=2,
                placeholder=(
                    "Line 1: PFO index per calo hit (comma-separated, -1 = unassigned)\n"
                    "Line 2: PFO index per track (comma-separated, -1 = unassigned)"
                ),
            )

            # Fill the CSV textboxes from the selected parquet event so the
            # user can inspect/edit the data before running inference.
            load_event_btn.click(
                fn=_load_event_into_csv,
                inputs=[parquet_path, event_idx],
                outputs=[csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links, load_event_status],
            )

        # ---- Run inference ----
        with gr.Accordion("3 · Results", open=True):
            run_btn = gr.Button("▶ Run Inference", variant="primary")
            inv_mass_output = gr.Markdown("")
            gr.Markdown("### Predicted Particles (HitPF)")
            particles_table = gr.Dataframe(label="Predicted particles")
            gr.Markdown("### MC Truth Particles")
            mc_particles_table = gr.Dataframe(label="MC truth particles (for comparison)")
            gr.Markdown("### Pandora Particles")
            pandora_particles_table = gr.Dataframe(label="Pandora PFO particles (for comparison)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Hit → HitPF Cluster 3D Map")
                    cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = HitPF cluster, size = energy)")
                with gr.Column():
                    gr.Markdown("### Hit → Pandora Cluster 3D Map")
                    pandora_cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = Pandora PFO, size = energy)")
            gr.Markdown("### Clustering Space 3D Map")
            clustering_space_plot = gr.Plot(label="Clustering space 3D scatter (GATr regressed coordinates)")

            # Output order must match the return order of run_inference_ui.
            run_btn.click(
                fn=run_inference_ui,
                inputs=[parquet_path, event_idx, csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links],
                outputs=[particles_table, cluster_plot, clustering_space_plot, pandora_cluster_plot, mc_particles_table, pandora_particles_table, inv_mass_output],
            )

    return demo
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
# ---------------------------------------------------------------------------
|
| 763 |
+
|
| 764 |
+
if __name__ == "__main__":
    # Command-line entry point: parse options, set the device, launch the UI.
    parser = argparse.ArgumentParser(description="HitPF Gradio UI")
    parser.add_argument("--device", default="cpu", help="Default device (cpu / cuda:0 / …)")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    opts = parser.parse_args()
    _set_device(opts.device)

    build_app().launch(share=opts.share)
|
config_files/config_hits_track_v4.yaml
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This one uses px, py, pz instead of theta, phi, to avoid possible errors
|
| 2 |
+
|
| 3 |
+
graph_config:
|
| 4 |
+
only_hits: false
|
| 5 |
+
prediction: true
|
| 6 |
+
muons: true
|
| 7 |
+
custom_model_kwargs:
|
| 8 |
+
# add custom model kwargs here
|
| 9 |
+
n_postgn_dense_blocks: 4
|
| 10 |
+
clust_space_norm: none
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#treename:
|
| 15 |
+
selection:
|
| 16 |
+
### use `&`, `|`, `~` for logical operations on numpy arrays
|
| 17 |
+
### can use functions from `math`, `np` (numpy), and `awkward` in the expression
|
| 18 |
+
#(jet_tightId==1) & (jet_no<2) & (fj_pt>200) & (fj_pt<2500) & (((sample_isQCD==0) & (fj_isQCD==0)) | ((sample_isQCD==1) & (fj_isQCD==1))) & (event_no%7!=0)
|
| 19 |
+
#(recojet_e>=5)
|
| 20 |
+
|
| 21 |
+
test_time_selection:
|
| 22 |
+
### selection to apply at test time (i.e., when running w/ --predict)
|
| 23 |
+
#(jet_tightId==1) & (jet_no<2) & (fj_pt>200) & (fj_pt<2500) & (((sample_isQCD==0) & (fj_isQCD==0)) | ((sample_isQCD==1) & (fj_isQCD==1))) & (event_no%7==0)
|
| 24 |
+
#(recojet_e<5)
|
| 25 |
+
|
| 26 |
+
new_variables:
|
| 27 |
+
### [format] name: formula
|
| 28 |
+
### can use functions from `math`, `np` (numpy), and `awkward` in the expression
|
| 29 |
+
#pfcand_mask: awkward.JaggedArray.ones_like(pfcand_etarel)
|
| 30 |
+
#sv_mask: awkward.JaggedArray.ones_like(sv_etarel)
|
| 31 |
+
#pfcand_mask: awkward.JaggedArray.ones_like(pfcand_e)
|
| 32 |
+
hit_mask: ak.ones_like(hit_e)
|
| 33 |
+
part_mask: ak.ones_like(part_p)
|
| 34 |
+
hit_e_nn: hit_e
|
| 35 |
+
part_p1: part_p
|
| 36 |
+
part_theta1: part_theta
|
| 37 |
+
part_phi1: part_phi
|
| 38 |
+
part_m1: part_m
|
| 39 |
+
part_pid1: part_pid
|
| 40 |
+
|
| 41 |
+
preprocess:
|
| 42 |
+
### method: [manual, auto] - whether to use manually specified parameters for variable standardization
|
| 43 |
+
### [note]: `[var]_mask` will not be transformed even if `method=auto`
|
| 44 |
+
method: auto
|
| 45 |
+
### data_fraction: fraction of events to use when calculating the mean/scale for the standardization
|
| 46 |
+
data_fraction: 0.1
|
| 47 |
+
|
| 48 |
+
inputs:
|
| 49 |
+
pf_points:
|
| 50 |
+
pad_mode: wrap
|
| 51 |
+
length: 25000
|
| 52 |
+
vars:
|
| 53 |
+
- [hit_x, null]
|
| 54 |
+
- [hit_y, null]
|
| 55 |
+
- [hit_z, null]
|
| 56 |
+
- [hit_px, null]
|
| 57 |
+
- [hit_py, null]
|
| 58 |
+
- [hit_pz, null]
|
| 59 |
+
pf_points_pfo:
|
| 60 |
+
pad_mode: wrap
|
| 61 |
+
length: 25000
|
| 62 |
+
vars:
|
| 63 |
+
- [hit__pandora_px, null]
|
| 64 |
+
- [hit__pandora_py, null]
|
| 65 |
+
- [hit__pandora_pz, null]
|
| 66 |
+
- [hit__pandora_x, null]
|
| 67 |
+
- [hit__pandora_y, null]
|
| 68 |
+
- [hit__pandora_z, null]
|
| 69 |
+
- [pandora_pid, null]
|
| 70 |
+
pf_features:
|
| 71 |
+
pad_mode: wrap
|
| 72 |
+
length: 25000
|
| 73 |
+
vars:
|
| 74 |
+
### [format 1]: var_name (no transformation)
|
| 75 |
+
### [format 2]: [var_name,
|
| 76 |
+
### subtract_by(optional, default=None, no transf. if preprocess.method=manual, auto transf. if preprocess.method=auto),
|
| 77 |
+
### multiply_by(optional, default=1),
|
| 78 |
+
### clip_min(optional, default=-5),
|
| 79 |
+
### clip_max(optional, default=5),
|
| 80 |
+
### pad_value(optional, default=0)]
|
| 81 |
+
|
| 82 |
+
- [hit_p, null]
|
| 83 |
+
- [hit_e, null]
|
| 84 |
+
- [part_theta , null]
|
| 85 |
+
- [part_phi , null]
|
| 86 |
+
- [part_p , null]
|
| 87 |
+
- [part_m, null]
|
| 88 |
+
- [part_pid, null]
|
| 89 |
+
- [part_isDecayedInCalorimeter, null]
|
| 90 |
+
- [part_isDecayedInTracker, null]
|
| 91 |
+
- [hit_pandora_cluster_energy, null]
|
| 92 |
+
- [hit_pandora_pfo_energy, null]
|
| 93 |
+
- [hit_chis, null]
|
| 94 |
+
- [part_px , null]
|
| 95 |
+
- [part_py , null]
|
| 96 |
+
- [part_pz , null]
|
| 97 |
+
- [part_vertex_x, null]
|
| 98 |
+
- [part_vertex_y, null]
|
| 99 |
+
- [part_vertex_z, null]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
pf_vectors:
|
| 103 |
+
length: 25000
|
| 104 |
+
pad_mode: wrap
|
| 105 |
+
vars:
|
| 106 |
+
- [hit_type, null] #0
|
| 107 |
+
- [hit_e_nn, null] #1
|
| 108 |
+
# #labels
|
| 109 |
+
# - [part_p1, null] #2
|
| 110 |
+
# - [part_theta1, null] #3
|
| 111 |
+
# - [part_phi1, null] #4
|
| 112 |
+
# - [part_m1, null] #15
|
| 113 |
+
# - [part_pid1, null] #6
|
| 114 |
+
pf_vectoronly:
|
| 115 |
+
length: 25000
|
| 116 |
+
pad_mode: wrap
|
| 117 |
+
vars:
|
| 118 |
+
- [hit_genlink0, null] # hit link to MC
|
| 119 |
+
- [hit_genlink1, null] # pandora_cluster if val data otherwise 0
|
| 120 |
+
- [hit_genlink2, null] # pandora_index_pfo if val data otherwise 0
|
| 121 |
+
      - [hit_genlink3, null] # hit link to daughter
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
pf_mask:
|
| 125 |
+
length: 25000
|
| 126 |
+
pad_mode: constant
|
| 127 |
+
vars:
|
| 128 |
+
- [hit_mask, null]
|
| 129 |
+
- [part_mask, null]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
labels:
|
| 133 |
+
### type can be `simple`, `custom`
|
| 134 |
+
### [option 1] use `simple` for binary/multi-class classification, then `value` is a list of 0-1 labels
|
| 135 |
+
#type: simple
|
| 136 |
+
#value: [
|
| 137 |
+
# hit_ty
|
| 138 |
+
# ]
|
| 139 |
+
### [option 2] otherwise use `custom` to define the label, then `value` is a map
|
| 140 |
+
# type: custom
|
| 141 |
+
# value:
|
| 142 |
+
# target_mass: np.where(fj_isQCD, fj_genjet_sdmass, fj_gen_mass)
|
| 143 |
+
|
| 144 |
+
observers:
|
| 145 |
+
|
| 146 |
+
|
scripts/evaluation.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluation / prediction run for a trained HitPF model.
# Loads a frozen clustering checkpoint (--load-model-weights-clustering) and an
# energy-correction checkpoint (--load-model-weights), then runs prediction
# (--predict) on the test parquet file with Pandora comparison enabled
# (--pandora), logging to wandb and writing outputs named by --name-output.
python -m src.train_lightning1 \
    --data-test /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/Z_uds_CLD_o2_v05_eval_v1/05/pf_tree_10100.parquet \
    --data-config config_files/config_hits_track_v4.yaml \
    --network-config src/models/wrapper/example_mode_gatr_noise.py \
    --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/041225_arc_05/ \
    --load-model-weights-clustering /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/041225_arc_05/_epoch=9_step=120000.ckpt \
    --load-model-weights /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/040226_basic_ecor/_epoch=2_step=24000.ckpt \
    --wandb-displayname eval_gun_drlog \
    --gpus 2 \
    --batch-size 20 \
    --num-workers 4 \
    --start-lr 1e-3 \
    --num-epochs 100 \
    --fetch-step 1 \
    --fetch-by-files \
    --log-wandb \
    --wandb-projectname mlpf_debug_eval \
    --wandb-entity fcc_ml \
    --frac_cluster_loss 0 \
    --qmin 1 \
    --use-average-cc-pos 0.99 \
    --correction \
    --freeze-clustering \
    --predict \
    --name-output test_plot_hitpf2 \
    --pandora
|
scripts/train_clustering.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training run for the clustering stage of HitPF.
# Trains on the Z→uds clustering dataset with object-condensation settings
# (--qmin 3, --use-average-cc-pos 0.98), logging to wandb and saving
# checkpoints under --model-prefix.
python -m src.train_lightning1 \
    --data-train /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/Z_uds_clustering_dataset_3/05/ \
    --data-config config_files/config_hits_track_v4.yaml \
    --network-config src/models/wrapper/example_mode_gatr_noise.py \
    --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/test_hitpf/ \
    --num-workers 4 \
    --gpus 0,1 \
    --batch-size 5 \
    --num-epochs 100 \
    --fetch-step 1 \
    --log-wandb \
    --wandb-displayname CLD_clustering_training \
    --wandb-projectname mlpf_debug \
    --wandb-entity ml4hep \
    --frac_cluster_loss 0 \
    --qmin 3 \
    --use-average-cc-pos 0.98 \
    --train-val-split 0.98 \
    --fetch-by-files \
    --train-batches 10
|
scripts/train_energy_pid.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training run for the energy-correction / PID stage of HitPF.
# Freezes the clustering network (--freeze-clustering), trains the correction
# head (--correction) on ground-truth clusters (--use-gt-clusters), logging to
# wandb and saving checkpoints under --model-prefix.
python -m src.train_lightning1 \
    --data-train /eos/experiment/fcc/users/m/mgarciam/mlpf/CLD/train/gun_ecort/05/ \
    --data-config config_files/config_hits_track_v4.yaml \
    --network-config src/models/wrapper/example_mode_gatr_noise.py \
    --model-prefix /eos/user/m/mgarciam/datasets_mlpf/models_trained_CLD/test_hitpf_ecor/ \
    --wandb-displayname E_PID_05_basicecor_v1_1 \
    --gpus 0 \
    --batch-size 20 \
    --num-workers 4 \
    --start-lr 1e-3 \
    --num-epochs 100 \
    --fetch-step 1 \
    --fetch-by-files \
    --train-val-split 0.98 \
    --train-batches 8000 \
    --log-wandb \
    --wandb-projectname mlpf_debug \
    --wandb-entity ml4hep \
    --frac_cluster_loss 0 \
    --qmin 1 \
    --use-average-cc-pos 0.99 \
    --correction \
    --freeze-clustering \
    --use-gt-clusters
|
src/data/config.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import yaml
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
from src.logger.logger import _logger
|
| 6 |
+
from src.data.tools import _get_variable_names
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _as_list(x):
|
| 10 |
+
if x is None:
|
| 11 |
+
return None
|
| 12 |
+
elif isinstance(x, (list, tuple)):
|
| 13 |
+
return x
|
| 14 |
+
else:
|
| 15 |
+
return [x]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _md5(fname):
|
| 19 |
+
'''https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file'''
|
| 20 |
+
import hashlib
|
| 21 |
+
hash_md5 = hashlib.md5()
|
| 22 |
+
with open(fname, "rb") as f:
|
| 23 |
+
for chunk in iter(lambda: f.read(4096), b""):
|
| 24 |
+
hash_md5.update(chunk)
|
| 25 |
+
return hash_md5.hexdigest()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DataConfig(object):
|
| 29 |
+
r"""Data loading configuration.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, print_info=True, **kwargs):
|
| 33 |
+
opts = {
|
| 34 |
+
'treename': None,
|
| 35 |
+
'selection': None,
|
| 36 |
+
'test_time_selection': None,
|
| 37 |
+
'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
|
| 38 |
+
'new_variables': {},
|
| 39 |
+
'inputs': {},
|
| 40 |
+
'labels': {},
|
| 41 |
+
'observers': [],
|
| 42 |
+
'monitor_variables': [],
|
| 43 |
+
'weights': None,
|
| 44 |
+
'graph_config': {},
|
| 45 |
+
'custom_model_kwargs': {}
|
| 46 |
+
}
|
| 47 |
+
for k, v in kwargs.items():
|
| 48 |
+
if v is not None:
|
| 49 |
+
if isinstance(opts[k], dict):
|
| 50 |
+
opts[k].update(v)
|
| 51 |
+
else:
|
| 52 |
+
opts[k] = v
|
| 53 |
+
# only information in ``self.options'' will be persisted when exporting to YAML
|
| 54 |
+
self.options = opts
|
| 55 |
+
if print_info:
|
| 56 |
+
_logger.debug(opts)
|
| 57 |
+
|
| 58 |
+
self.selection = opts['selection']
|
| 59 |
+
self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
|
| 60 |
+
self.var_funcs = copy.deepcopy(opts['new_variables'])
|
| 61 |
+
# preprocessing config
|
| 62 |
+
self.preprocess = opts['preprocess']
|
| 63 |
+
self._auto_standardization = opts['preprocess']['method'].lower().startswith('auto')
|
| 64 |
+
self._missing_standardization_info = False
|
| 65 |
+
self.preprocess_params = opts['preprocess']['params'] if opts['preprocess']['params'] is not None else {}
|
| 66 |
+
# inputs
|
| 67 |
+
self.input_names = tuple(opts['inputs'].keys())
|
| 68 |
+
self.input_dicts = {k: [] for k in self.input_names}
|
| 69 |
+
self.input_shapes = {}
|
| 70 |
+
for k, o in opts['inputs'].items():
|
| 71 |
+
self.input_shapes[k] = (-1, len(o['vars']), o['length'])
|
| 72 |
+
for v in o['vars']:
|
| 73 |
+
v = _as_list(v)
|
| 74 |
+
self.input_dicts[k].append(v[0])
|
| 75 |
+
|
| 76 |
+
if opts['preprocess']['params'] is None:
|
| 77 |
+
|
| 78 |
+
def _get(idx, default):
|
| 79 |
+
try:
|
| 80 |
+
return v[idx]
|
| 81 |
+
except IndexError:
|
| 82 |
+
return default
|
| 83 |
+
|
| 84 |
+
params = {'length': o['length'], 'pad_mode': o.get('pad_mode', 'constant').lower(),
|
| 85 |
+
'center': _get(1, 'auto' if self._auto_standardization else None),
|
| 86 |
+
'scale': _get(2, 1), 'min': _get(3, -5), 'max': _get(4, 5), 'pad_value': _get(5, 0)}
|
| 87 |
+
|
| 88 |
+
if v[0] in self.preprocess_params and params != self.preprocess_params[v[0]]:
|
| 89 |
+
raise RuntimeError(
|
| 90 |
+
'Incompatible info for variable %s, had: \n %s\nnow got:\n %s' %
|
| 91 |
+
(v[0], str(self.preprocess_params[v[0]]), str(params)))
|
| 92 |
+
if k.endswith('_mask') and params['pad_mode'] != 'constant':
|
| 93 |
+
raise RuntimeError('The `pad_mode` must be set to `constant` for the mask input `%s`' % k)
|
| 94 |
+
if params['center'] == 'auto':
|
| 95 |
+
self._missing_standardization_info = True
|
| 96 |
+
self.preprocess_params[v[0]] = params
|
| 97 |
+
|
| 98 |
+
# observers
|
| 99 |
+
self.observer_names = tuple(opts['observers'])
|
| 100 |
+
# monitor variables
|
| 101 |
+
self.monitor_variables = tuple(opts['monitor_variables'])
|
| 102 |
+
# Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
|
| 103 |
+
self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables
|
| 104 |
+
|
| 105 |
+
# remove self mapping from var_funcs
|
| 106 |
+
for k, v in self.var_funcs.items():
|
| 107 |
+
if k == v:
|
| 108 |
+
del self.var_funcs[k]
|
| 109 |
+
|
| 110 |
+
if print_info:
|
| 111 |
+
def _log(msg, *args, **kwargs):
|
| 112 |
+
_logger.info(msg, *args, color='lightgray', **kwargs)
|
| 113 |
+
_log('preprocess config: %s', str(self.preprocess))
|
| 114 |
+
_log('selection: %s', str(self.selection))
|
| 115 |
+
_log('test_time_selection: %s', str(self.test_time_selection))
|
| 116 |
+
_log('var_funcs:\n - %s', '\n - '.join(str(it) for it in self.var_funcs.items()))
|
| 117 |
+
_log('input_names: %s', str(self.input_names))
|
| 118 |
+
_log('input_dicts:\n - %s', '\n - '.join(str(it) for it in self.input_dicts.items()))
|
| 119 |
+
_log('input_shapes:\n - %s', '\n - '.join(str(it) for it in self.input_shapes.items()))
|
| 120 |
+
_log('preprocess_params:\n - %s', '\n - '.join(str(it) for it in self.preprocess_params.items()))
|
| 121 |
+
#_log('label_names: %s', str(self.label_names))
|
| 122 |
+
_log('observer_names: %s', str(self.observer_names))
|
| 123 |
+
_log('monitor_variables: %s', str(self.monitor_variables))
|
| 124 |
+
if opts['weights'] is not None:
|
| 125 |
+
if self.use_precomputed_weights:
|
| 126 |
+
_log('weight: %s' % self.var_funcs[self.weight_name])
|
| 127 |
+
else:
|
| 128 |
+
for k in ['reweight_method', 'reweight_basewgt', 'reweight_branches', 'reweight_bins',
|
| 129 |
+
'reweight_classes', 'class_weights', 'reweight_threshold',
|
| 130 |
+
'reweight_discard_under_overflow']:
|
| 131 |
+
_log('%s: %s' % (k, getattr(self, k)))
|
| 132 |
+
|
| 133 |
+
# parse config
|
| 134 |
+
self.keep_branches = set()
|
| 135 |
+
aux_branches = set()
|
| 136 |
+
# selection
|
| 137 |
+
if self.selection:
|
| 138 |
+
aux_branches.update(_get_variable_names(self.selection))
|
| 139 |
+
# test time selection
|
| 140 |
+
if self.test_time_selection:
|
| 141 |
+
aux_branches.update(_get_variable_names(self.test_time_selection))
|
| 142 |
+
# var_funcs
|
| 143 |
+
self.keep_branches.update(self.var_funcs.keys())
|
| 144 |
+
for expr in self.var_funcs.values():
|
| 145 |
+
aux_branches.update(_get_variable_names(expr))
|
| 146 |
+
# inputs
|
| 147 |
+
for names in self.input_dicts.values():
|
| 148 |
+
self.keep_branches.update(names)
|
| 149 |
+
# labels
|
| 150 |
+
#self.keep_branches.update(self.label_names)
|
| 151 |
+
# weight
|
| 152 |
+
#if self.weight_name:
|
| 153 |
+
# self.keep_branches.add(self.weight_name)
|
| 154 |
+
# if not self.use_precomputed_weights:
|
| 155 |
+
# aux_branches.update(self.reweight_branches)
|
| 156 |
+
# aux_branches.update(self.reweight_classes)
|
| 157 |
+
# observers
|
| 158 |
+
self.keep_branches.update(self.observer_names)
|
| 159 |
+
# monitor variables
|
| 160 |
+
self.keep_branches.update(self.monitor_variables)
|
| 161 |
+
# keep and drop
|
| 162 |
+
self.drop_branches = (aux_branches - self.keep_branches)
|
| 163 |
+
self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys()) #- {self.weight_name, }
|
| 164 |
+
if print_info:
|
| 165 |
+
_logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
|
| 166 |
+
_logger.debug('load_branches:\n %s', ','.join(self.load_branches))
|
| 167 |
+
|
| 168 |
+
def __getattr__(self, name):
    """Fall back to the `options` dict for attributes not set on the instance.

    Python's attribute protocol requires `__getattr__` to raise
    AttributeError for missing names; the original raised KeyError, which
    breaks `hasattr()`/`getattr(obj, name, default)` and copy/pickle probes.

    NOTE(review): if `self.options` itself is ever absent from `__dict__`
    (e.g. during unpickling before `__init__` runs) this still recurses —
    confirm against how instances are constructed.
    """
    try:
        return self.options[name]
    except KeyError:
        raise AttributeError(name) from None
|
| 170 |
+
|
| 171 |
+
def dump(self, fp):
    """Serialize the full options dict to the YAML file at path `fp`.

    `sort_keys=False` preserves the original key order of `self.options`.
    """
    with open(fp, 'w') as f:
        yaml.safe_dump(self.options, f, sort_keys=False)
|
| 174 |
+
|
| 175 |
+
@classmethod
def load(cls, fp, load_observers=True, load_reweight_info=True, extra_selection=None, extra_test_selection=None):
    """Construct a config object from the YAML file at `fp`.

    Args:
        fp: path to the YAML options file.
        load_observers: if False, drop the `observers` section.
        load_reweight_info: if False, drop the `weights` section.
        extra_selection: optional expression AND-ed onto `selection`.
        extra_test_selection: optional expression AND-ed onto
            `test_time_selection`; raises if that key is absent in the file.

    Returns:
        A new instance built via ``cls(**options)``.
    """
    with open(fp) as f:
        options = yaml.safe_load(f)
    if not load_observers:
        options['observers'] = None
    if not load_reweight_info:
        options['weights'] = None
    if extra_selection:
        # Combine with the existing selection using the awkward/numpy `&` operator.
        options['selection'] = '(%s) & (%s)' % (options['selection'], extra_selection)
    if extra_test_selection:
        if 'test_time_selection' not in options:
            raise RuntimeError('`test_time_selection` is not defined in the yaml file!')
        options['test_time_selection'] = '(%s) & (%s)' % (options['test_time_selection'], extra_test_selection)
    return cls(**options)
|
| 190 |
+
|
| 191 |
+
def copy(self):
    """Return a new instance built from a deep copy of `self.options`.

    `print_info=False` suppresses the constructor's logging on the copy.
    """
    return self.__class__(print_info=False, **copy.deepcopy(self.options))
|
| 193 |
+
|
| 194 |
+
def __copy__(self):
    """Shallow-copy protocol hook; delegates to `copy()` (which deep-copies options)."""
    return self.copy()
|
| 196 |
+
|
| 197 |
+
def __deepcopy__(self, memo):
    """Deep-copy protocol hook; delegates to `copy()`. `memo` is intentionally unused."""
    return self.copy()
|
| 199 |
+
|
| 200 |
+
def export_json(self, fp):
    """Export input/variable metadata as a JSON file (for external inference tools).

    Writes, per input group: the variable names, per-variable standardization
    info (center/scale/bounds/pad), and the group's `var_length`.
    """
    import json
    j = {'output_names': self.label_value, 'input_names': self.input_names}
    for k, v in self.input_dicts.items():
        j[k] = {'var_names': v, 'var_infos': {}}
        for var_name in v:
            # NOTE(review): `var_length` is overwritten on every iteration, so it
            # ends up as the LAST variable's length — presumably all variables in
            # a group share one length; confirm.
            j[k]['var_length'] = self.preprocess_params[var_name]['length']
            info = self.preprocess_params[var_name]
            j[k]['var_infos'][var_name] = {
                # center=None means "no standardization": median 0 and wide-open bounds.
                'median': 0 if info['center'] is None else info['center'],
                'norm_factor': info['scale'],
                'replace_inf_value': 0,
                'lower_bound': -1e32 if info['center'] is None else info['min'],
                'upper_bound': 1e32 if info['center'] is None else info['max'],
                'pad': info['pad_value']
            }
    with open(fp, 'w') as f:
        json.dump(j, f, indent=2)
|
| 218 |
+
|
src/data/fileio.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import awkward as ak
|
| 3 |
+
import tqdm
|
| 4 |
+
import traceback
|
| 5 |
+
from src.data.tools import _concat, _concat_records
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _read_hdf5(filepath, branches, load_range=None):
    """Read `branches` from an HDF5 file into an `ak.Array`.

    Args:
        filepath: path to the HDF5 file (PyTables layout, arrays at root).
        branches: iterable of node names to load.
        load_range: optional (start_frac, end_frac) fractions of the total
            entries to keep; defaults to the whole file.
    """
    import tables
    tables.set_blosc_max_threads(4)
    with tables.open_file(filepath) as f:
        # Materialize each requested node fully before slicing.
        outputs = {k: getattr(f.root, k)[:] for k in branches}
    if load_range is None:
        load_range = (0, 1)
    # Fractional range -> absolute indices; `stop` is forced past `start`
    # so at least one entry is always kept.
    start = math.trunc(load_range[0] * len(outputs[branches[0]]))
    stop = max(start + 1, math.trunc(load_range[1] * len(outputs[branches[0]])))
    for k, v in outputs.items():
        outputs[k] = v[start:stop]
    return ak.Array(outputs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _read_root(filepath, branches, load_range=None, treename=None):
    """Read `branches` from a ROOT TTree via uproot.

    Args:
        filepath: path to the ROOT file.
        branches: branch name filter passed to `tree.arrays(filter_name=...)`.
        load_range: optional (start_frac, end_frac) fraction of entries.
        treename: tree to read; auto-detected when the file holds exactly one TTree.

    Raises:
        RuntimeError: when `treename` is None and the file has != 1 tree.
    """
    import uproot
    with uproot.open(filepath) as f:
        if treename is None:
            # Strip the ";cycle" suffix uproot appends to keys.
            treenames = set([k.split(';')[0] for k, v in f.items() if getattr(v, 'classname', '') == 'TTree'])
            if len(treenames) == 1:
                treename = treenames.pop()
            else:
                # NOTE(review): the message interpolates `branches`, not the
                # ambiguous tree names — likely meant str(treenames); confirm.
                raise RuntimeError(
                    'Need to specify `treename` as more than one trees are found in file %s: %s' %
                    (filepath, str(branches)))
        tree = f[treename]
        if load_range is not None:
            start = math.trunc(load_range[0] * tree.num_entries)
            # Guarantee a non-empty slice.
            stop = max(start + 1, math.trunc(load_range[1] * tree.num_entries))
        else:
            start, stop = None, None
        outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop)
    return outputs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _read_awkd(filepath, branches, load_range=None):
    """Read `branches` from a legacy awkward0 (.awkd) file into an `ak.Array`.

    `load_range` is a fractional (start, end) window; defaults to everything.
    Entries are converted from awkward0 to awkward1 via `ak.from_awkward0`.
    """
    import awkward0
    with awkward0.load(filepath) as f:
        outputs = {k: f[k] for k in branches}
    if load_range is None:
        load_range = (0, 1)
    start = math.trunc(load_range[0] * len(outputs[branches[0]]))
    # At least one entry is always kept.
    stop = max(start + 1, math.trunc(load_range[1] * len(outputs[branches[0]])))
    for k, v in outputs.items():
        outputs[k] = ak.from_awkward0(v[start:stop])
    return ak.Array(outputs)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _slice_record(record, start, stop):
    """Return a new `ak.Record` with every field sliced to `[start:stop]`."""
    return ak.Record({field: record[field][start:stop] for field in record.fields})
|
| 62 |
+
|
| 63 |
+
def _read_parquet(filepath, load_range=None):
    """Load a parquet file as an awkward record, optionally keeping only a
    fractional entry window.

    The number of events is taken from the "X_track" field; `load_range`
    fractions are converted to indices with at least one entry kept.
    """
    outputs = ak.from_parquet(filepath)
    num_events = len(outputs["X_track"])
    if load_range is None:
        return outputs
    begin = math.trunc(load_range[0] * num_events)
    end = max(begin + 1, math.trunc(load_range[1] * num_events))
    return _slice_record(outputs, begin, end)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _read_files(filelist, branches=None, load_range=None, show_progressbar=False, **kwargs):
    """Read and concatenate a list of input files into one awkward record.

    Args:
        filelist: paths to read. Only parquet files are actually parsed;
            other known extensions pass the check but will fail in
            `_read_parquet` (see NOTE below).
        branches: branch names requested by callers. Accepted for
            signature compatibility — callers in preprocess.py pass the
            branch set as the second positional argument (previously it
            was silently consumed by `load_range`). Parquet reading loads
            all fields, so it is currently unused.
        load_range: optional (start_frac, end_frac) entry window per file.
        show_progressbar: wrap the file loop in tqdm.
        **kwargs: extra reader options (e.g. `treename`), currently ignored.

    Returns:
        One `ak.Record` with all files' events concatenated.

    Raises:
        RuntimeError: on an unsupported file extension, or when zero
            entries were loaded overall.
    """
    import os
    table = []
    if show_progressbar:
        filelist = tqdm.tqdm(filelist)
    for filepath in filelist:
        ext = os.path.splitext(filepath)[1]
        if ext not in ('.h5', '.root', '.awkd', '.parquet'):
            raise RuntimeError('File %s of type `%s` is not supported!' % (filepath, ext))
        # NOTE(review): every accepted extension is routed to the parquet
        # reader; .h5/.root/.awkd would need _read_hdf5/_read_root/_read_awkd.
        a = _read_parquet(filepath, load_range=load_range)
        if a is not None:
            table.append(a)
    table = _concat_records(table)  # ak.Record
    if len(table["X_track"]) == 0:
        raise RuntimeError(f'Zero entries loaded when reading files {filelist} with `load_range`={load_range}.')
    return table
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _write_root(file, table, treename='Events', compression=-1, step=1048576):
    """Write a dict of equal-length numpy arrays to a ROOT TTree in chunks.

    Args:
        file: output path.
        table: mapping of branch name -> numpy array (all same length).
        treename: name of the TTree to create.
        compression: uproot compression object; -1 selects LZ4 level 4.
        step: entries per `tree.extend` call.

    Bug fixed: the loop condition was `start < total - 1`, which wrote
    nothing for a 1-entry table and dropped the final entry whenever
    total == k*step + 1.
    """
    import uproot
    if compression == -1:
        compression = uproot.LZ4(4)
    with uproot.recreate(file, compression=compression) as fout:
        tree = fout.mktree(treename, {k: v.dtype for k, v in table.items()})
        # Hoist the (loop-invariant) total length out of the loop.
        total = len(next(iter(table.values())))
        start = 0
        while start < total:
            tree.extend({k: v[start:start + step] for k, v in table.items()})
            start += step
|
src/data/preprocess.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import glob
|
| 3 |
+
import copy
|
| 4 |
+
import numpy as np
|
| 5 |
+
import awkward as ak
|
| 6 |
+
|
| 7 |
+
from src.data.tools import _get_variable_names, _eval_expr
|
| 8 |
+
from src.data.fileio import _read_files
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _apply_selection(table, selection):
    """Keep only the rows of `table` for which `selection` evaluates truthy.

    A `selection` of None is a no-op and returns the table unchanged.
    """
    if selection is None:
        return table
    mask = ak.values_astype(_eval_expr(selection, table), 'bool')
    return table[mask]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _build_new_variables(table, funcs):
    """Attach derived columns to `table` from `funcs` (name -> expression).

    Columns already present in the table are left untouched; a `funcs` of
    None is a no-op. The table is modified in place and also returned.
    """
    if funcs is None:
        return table
    for name, expression in funcs.items():
        if name not in table.fields:
            table[name] = _eval_expr(expression, table)
    return table
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _clean_up(table, drop_branches):
    """Return `table` restricted to the fields NOT listed in `drop_branches`."""
    kept = [field for field in table.fields if field not in drop_branches]
    return table[kept]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _build_weights(table, data_config, reweight_hists=None):
    """Compute a per-event weight array from the config's reweighting setup.

    Either returns the precomputed weight branch directly, or looks up each
    event's weight in a per-class 2D histogram binned in the two
    `reweight_branches` variables.

    Args:
        table: awkward table with the reweighting/class branches.
        data_config: config object carrying the `weights` section.
        reweight_hists: optional mapping class-label -> 2D histogram;
            falls back to `data_config.reweight_hists`.

    Returns:
        numpy float32 array of per-event weights (0 for events in no class).

    Bug fixed: the mismatch warning called an undefined `warn` (NameError
    whenever it triggered); now uses `warnings.warn`.
    """
    import warnings

    if data_config.weight_name is None:
        raise RuntimeError('Error when building weights: `weight_name` is None!')
    if data_config.use_precomputed_weights:
        return ak.to_numpy(table[data_config.weight_name])
    else:
        x_var, y_var = data_config.reweight_branches
        x_bins, y_bins = data_config.reweight_bins
        rwgt_sel = None
        if data_config.reweight_discard_under_overflow:
            # Events outside the binning range keep weight 0.
            rwgt_sel = (table[x_var] >= min(x_bins)) & (table[x_var] <= max(x_bins)) & \
                       (table[y_var] >= min(y_bins)) & (table[y_var] <= max(y_bins))
        # init w/ wgt=0: events not belonging to any class in `reweight_classes` will get a weight of 0 at the end
        wgt = np.zeros(len(table), dtype='float32')
        sum_evts = 0
        if reweight_hists is None:
            reweight_hists = data_config.reweight_hists
        for label, hist in reweight_hists.items():
            pos = table[label] == 1
            if rwgt_sel is not None:
                pos = (pos & rwgt_sel)
            rwgt_x_vals = ak.to_numpy(table[x_var][pos])
            rwgt_y_vals = ak.to_numpy(table[y_var][pos])
            # digitize returns 1-based bin numbers; shift to 0-based and clamp
            # into the valid histogram index range.
            x_indices = np.clip(np.digitize(
                rwgt_x_vals, x_bins) - 1, a_min=0, a_max=len(x_bins) - 2)
            y_indices = np.clip(np.digitize(
                rwgt_y_vals, y_bins) - 1, a_min=0, a_max=len(y_bins) - 2)
            wgt[pos] = hist[x_indices, y_indices]
            sum_evts += np.sum(pos)
        if sum_evts != len(table):
            warnings.warn(
                'Not all selected events used in the reweighting. '
                'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
                '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
            )
        if data_config.reweight_basewgt:
            # Multiply in the per-event base weight on top of the histogram weight.
            wgt *= ak.to_numpy(table[data_config.basewgt_name])
        return wgt
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class AutoStandardizer(object):
    r"""AutoStandardizer.
    Class to compute the variable standardization information.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    def __init__(self, filelist, data_config):
        # A dict filelist maps sample names -> file lists; flatten to one list.
        if isinstance(filelist, dict):
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(
            filelist, (list, tuple)) else glob.glob(filelist)
        # Work on a copy so the caller's config is only updated via produce().
        self._data_config = data_config.copy()
        # Only a fraction of the data is needed to estimate the percentiles.
        self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

    def read_file(self, filelist):
        """Load only the branches needed for variables with center == 'auto'."""
        self.keep_branches = set()
        self.load_branches = set()
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                self.keep_branches.add(k)
                # Derived variables need their source branches loaded instead.
                if k in self._data_config.var_funcs:
                    expr = self._data_config.var_funcs[k]
                    self.load_branches.update(_get_variable_names(expr))
                else:
                    self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))

        # NOTE(review): fileio._read_files is defined as
        # (filelist, load_range=None, show_progressbar=False, **kwargs), so the
        # positional `self.load_branches` lands in `load_range` here — looks like
        # a signature mismatch; confirm against the intended reader signature.
        table = _read_files(filelist, self.load_branches, self.load_range,
                            show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_preprocess_params(self, table):
        """Fill in center/scale for every 'auto' variable from data percentiles."""
        preprocess_params = copy.deepcopy(self._data_config.preprocess_params)
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                if k.endswith('_mask'):
                    # Masks are never standardized.
                    params['center'] = None
                else:
                    a = ak.to_numpy(ak.flatten(table[k], axis=None))
                    # check for NaN
                    if np.any(np.isnan(a)):
                        # NOTE(review): sleeping instead of warning/raising —
                        # presumably a placeholder for a log message; confirm.
                        time.sleep(10)
                        a = np.nan_to_num(a)
                    # Robust center/scale: median and the larger one-sided
                    # 16–84 percentile spread (≈1 sigma for a Gaussian).
                    low, center, high = np.percentile(a, [16, 50, 84])
                    scale = max(high - center, center - low)
                    scale = 1 if scale == 0 else 1. / scale
                    params['center'] = float(center)
                    params['scale'] = float(scale)
                preprocess_params[k] = params

        return preprocess_params

    def produce(self, output=None):
        """Compute the standardization info and return the updated config.

        If `output` is given, the updated config is also dumped to that path.
        """
        table = self.read_file(self._filelist)
        preprocess_params = self.make_preprocess_params(table)
        self._data_config.preprocess_params = preprocess_params
        # must also propogate the changes to `data_config.options` so it can be persisted
        self._data_config.options['preprocess']['params'] = preprocess_params
        if output:
            self._data_config.dump(output)
        return self._data_config
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class WeightMaker(object):
    r"""WeightMaker.
    Class to make reweighting information.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    def __init__(self, filelist, data_config):
        # A dict filelist maps sample names -> file lists; flatten to one list.
        if isinstance(filelist, dict):
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(filelist, (list, tuple)) else glob.glob(filelist)
        # Work on a copy so the caller's config is only updated via produce().
        self._data_config = data_config.copy()

    def read_file(self, filelist):
        """Load the branches needed for reweighting (vars, classes, base weight)."""
        self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes +
                                 (self._data_config.basewgt_name,))
        self.load_branches = set()
        for k in self.keep_branches:
            # Derived variables need their source branches loaded instead.
            if k in self._data_config.var_funcs:
                expr = self._data_config.var_funcs[k]
                self.load_branches.update(_get_variable_names(expr))
            else:
                self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        # NOTE(review): fileio._read_files takes (filelist, load_range=None, ...),
        # so `self.load_branches` lands in `load_range` here — confirm signature.
        table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_weights(self, table):
        """Build per-class 2D weight histograms in the two reweighting variables.

        Supports two methods: 'flat' (flatten each class's 2D spectrum) and
        'ref' (reshape every class to the spectrum of class 0), then equalizes
        the effective class yields according to `class_weights`.
        """
        x_var, y_var = self._data_config.reweight_branches
        x_bins, y_bins = self._data_config.reweight_bins
        if not self._data_config.reweight_discard_under_overflow:
            # clip variables to be within bin ranges
            x_min, x_max = min(x_bins), max(x_bins)
            y_min, y_max = min(y_bins), max(y_bins)
            table[x_var] = np.clip(table[x_var], min(x_bins), max(x_bins))
            table[y_var] = np.clip(table[y_var], min(y_bins), max(y_bins))
        sum_evts = 0
        max_weight = 0.9
        raw_hists = {}
        class_events = {}
        result = {}
        for label in self._data_config.reweight_classes:
            pos = (table[label] == 1)
            x = ak.to_numpy(table[x_var][pos])
            y = ak.to_numpy(table[y_var][pos])
            hist, _, _ = np.histogram2d(x, y, bins=self._data_config.reweight_bins)
            sum_evts += hist.sum()
            if self._data_config.reweight_basewgt:
                # Re-histogram using the per-event base weight.
                w = ak.to_numpy(table[self._data_config.basewgt_name][pos])
                hist, _, _ = np.histogram2d(x, y, weights=w, bins=self._data_config.reweight_bins)

            raw_hists[label] = hist.astype('float32')
            result[label] = hist.astype('float32')
        if sum_evts != len(table):
            # NOTE(review): sleeping instead of warning — presumably a
            # placeholder for a consistency warning; confirm.
            time.sleep(10)

        if self._data_config.reweight_method == 'flat':
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                hist = result[label]
                # Ignore near-empty bins (<1% of the median occupancy).
                threshold_ = np.median(hist[hist > 0]) * 0.01
                nonzero_vals = hist[hist > threshold_]
                min_val, med_val = np.min(nonzero_vals), np.median(hist)  # not really used
                ref_val = np.percentile(nonzero_vals, self._data_config.reweight_threshold)
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                wgt = np.clip(np.nan_to_num(ref_val / hist, posinf=0), 0, 1)
                result[label] = wgt
                # divide by classwgt here will effective increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        elif self._data_config.reweight_method == 'ref':
            # use class 0 as the reference
            hist_ref = raw_hists[self._data_config.reweight_classes[0]]
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                ratio = np.nan_to_num(hist_ref / result[label], posinf=0)
                upper = np.percentile(ratio[ratio > 0], 100 - self._data_config.reweight_threshold)
                wgt = np.clip(ratio / upper, 0, 1)  # -> [0,1]
                result[label] = wgt
                # divide by classwgt here will effective increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        # ''equalize'' all classes
        # multiply by max_weight (<1) to add some randomness in the sampling
        min_nevt = min(class_events.values()) * max_weight
        for label in self._data_config.reweight_classes:
            class_wgt = float(min_nevt) / class_events[label]
            result[label] *= class_wgt

        if self._data_config.reweight_basewgt:
            # Normalize so that the bulk of final per-event weights is <= 1.
            wgts = _build_weights(table, self._data_config, reweight_hists=result)
            wgt_ref = np.percentile(wgts, 100 - self._data_config.reweight_threshold)
            for label in self._data_config.reweight_classes:
                result[label] /= wgt_ref

        return result

    def produce(self, output=None):
        """Compute the reweighting histograms and return the updated config.

        If `output` is given, the updated config is also dumped to that path.
        """
        table = self.read_file(self._filelist)
        wgts = self.make_weights(table)
        self._data_config.reweight_hists = wgts
        # must also propogate the changes to `data_config.options` so it can be persisted
        self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
        if output:
            self._data_config.dump(output)
        return self._data_config
|
src/data/tools.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import awkward as ak
|
| 5 |
+
|
| 6 |
+
def build_dummy_array(num, dtype=np.int64):
    """Return an awkward Array of `num` empty lists (all offsets zero)."""
    offsets = ak.index.Index64(np.zeros(num + 1, dtype=np.int64))
    content = ak.from_numpy(np.array([], dtype=dtype), highlevel=False)
    return ak.Array(ak.contents.ListOffsetArray(offsets, content))
|
| 13 |
+
|
| 14 |
+
def _concat_records(table):
    """Merge a list of per-file `ak.Record`s into a single record.

    Events are gathered field by field across all input records; a field
    that is empty everywhere is replaced by a dummy float32 list array so
    downstream code sees a consistent layout.
    """
    merged = {}
    for field in table[0].fields:
        events = [rec[field][i] for rec in table for i in range(len(rec[field]))]
        merged[field] = ak.from_iter(events)
    for field in merged.keys():
        if len(ak.flatten(merged[field])) == 0:
            merged[field] = build_dummy_array(len(merged[field]), np.float32)
    return ak.Record(merged)
|
| 21 |
+
|
| 22 |
+
def _concat(arrays, axis=0):
|
| 23 |
+
if len(arrays) == 0:
|
| 24 |
+
return np.array([])
|
| 25 |
+
if isinstance(arrays[0], np.ndarray):
|
| 26 |
+
return np.concatenate(arrays, axis=axis)
|
| 27 |
+
else:
|
| 28 |
+
return ak.concatenate(arrays, axis=axis)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _stack(arrays, axis=1):
|
| 32 |
+
if len(arrays) == 0:
|
| 33 |
+
return np.array([])
|
| 34 |
+
if isinstance(arrays[0], np.ndarray):
|
| 35 |
+
return np.stack(arrays, axis=axis)
|
| 36 |
+
else:
|
| 37 |
+
return ak.concatenate(arrays, axis=axis)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _pad_vector(a, value=-1, dtype="float32"):
|
| 41 |
+
maxlen = 2000
|
| 42 |
+
maxlen2 = 5
|
| 43 |
+
|
| 44 |
+
x = (np.ones((len(a), maxlen, maxlen2)) * value).astype(dtype)
|
| 45 |
+
for idx, s in enumerate(a):
|
| 46 |
+
for idx_vec, s_vec in enumerate(s):
|
| 47 |
+
x[idx, idx_vec, : len(s_vec)] = s_vec
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _pad(a, maxlen, value=0, dtype="float32"):
|
| 52 |
+
if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:
|
| 53 |
+
return a
|
| 54 |
+
elif isinstance(a, ak.Array):
|
| 55 |
+
if a.ndim == 1:
|
| 56 |
+
a = ak.unflatten(a, 1)
|
| 57 |
+
a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)
|
| 58 |
+
return ak.values_astype(a, dtype)
|
| 59 |
+
else:
|
| 60 |
+
x = (np.ones((len(a), maxlen)) * value).astype(dtype)
|
| 61 |
+
for idx, s in enumerate(a):
|
| 62 |
+
if not len(s):
|
| 63 |
+
continue
|
| 64 |
+
trunc = s[:maxlen].astype(dtype)
|
| 65 |
+
x[idx, : len(trunc)] = trunc
|
| 66 |
+
return x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _repeat_pad(a, maxlen, shuffle=False, dtype="float32"):
    """Pad each row of jagged array `a` to `maxlen`, filling the padded slots
    with values recycled (tiled) from the whole flattened array instead of a
    constant.

    Args:
        a: awkward jagged array.
        maxlen: target row length.
        shuffle: shuffle the recycled filler values before use.
        dtype: dtype of the returned array.
    """
    x = ak.to_numpy(ak.flatten(a))
    # Tile the flat values until there are enough to fill every padded slot.
    x = np.tile(x, int(np.ceil(len(a) * maxlen / len(x))))
    if shuffle:
        np.random.shuffle(x)
    x = x[: len(a) * maxlen].reshape((len(a), maxlen))
    # mask is 1 exactly where _pad added padding (zeros elsewhere), so the
    # sum below overwrites only the padded positions with recycled values.
    mask = _pad(ak.zeros_like(a), maxlen, value=1)
    x = _pad(a, maxlen) + mask * x
    return ak.values_astype(x, dtype)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _clip(a, a_min, a_max):
|
| 81 |
+
try:
|
| 82 |
+
return np.clip(a, a_min, a_max)
|
| 83 |
+
except ValueError:
|
| 84 |
+
return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _knn(support, query, k, n_jobs=1):
|
| 88 |
+
from scipy.spatial import cKDTree
|
| 89 |
+
|
| 90 |
+
kdtree = cKDTree(support)
|
| 91 |
+
d, idx = kdtree.query(query, k, n_jobs=n_jobs)
|
| 92 |
+
return idx
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1):
    """Run `_knn` per (support, query) pair and pack results into one dense
    (n_batches, maxlen_q, k) int32 array.

    Rows beyond each query's length keep the sentinel index `maxlen_s - 1`;
    supports/queries are clipped to `maxlen_s`/`maxlen_q` points.
    """
    assert len(supports) == len(queries)
    maxlen_q = maxlen_s if maxlen_q is None else maxlen_q
    out = np.full((len(supports), maxlen_q, k), maxlen_s - 1, dtype="int32")
    for batch_idx, (sup, qry) in enumerate(zip(supports, queries)):
        qry = qry[:maxlen_q]
        neighbours = _knn(sup[:maxlen_s], qry, k, n_jobs=n_jobs)
        out[batch_idx, : len(qry), :] = neighbours.reshape((-1, k))  # (len(qry), k)
    return out
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _batch_permute_indices(array, maxlen):
|
| 112 |
+
batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1))
|
| 113 |
+
for i, a in enumerate(array):
|
| 114 |
+
batch_permute_idx[i, : len(a)] = np.random.permutation(len(a[:maxlen]))
|
| 115 |
+
return batch_permute_idx
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _batch_argsort(array, maxlen):
|
| 119 |
+
batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1))
|
| 120 |
+
for i, a in enumerate(array):
|
| 121 |
+
batch_argsort_idx[i, : len(a)] = np.argsort(a[:maxlen])
|
| 122 |
+
return batch_argsort_idx
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _batch_gather(array, indices):
|
| 126 |
+
out = array.zeros_like()
|
| 127 |
+
for i, (a, idx) in enumerate(zip(array, indices)):
|
| 128 |
+
maxlen = min(len(a), len(idx))
|
| 129 |
+
out[i][:maxlen] = a[idx[:maxlen]]
|
| 130 |
+
return out
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _p4_from_pxpypze(px, py, pz, energy):
    """Build four-momentum records from Cartesian momentum components and energy."""
    import vector

    vector.register_awkward()
    components = {"px": px, "py": py, "pz": pz, "energy": energy}
    return vector.zip(components)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _p4_from_ptetaphie(pt, eta, phi, energy):
    """Build four-momentum records from (pt, eta, phi) and energy."""
    import vector

    vector.register_awkward()
    components = {"pt": pt, "eta": eta, "phi": phi, "energy": energy}
    return vector.zip(components)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _p4_from_ptetaphim(pt, eta, phi, mass):
    """Build four-momentum records from (pt, eta, phi) and mass."""
    import vector

    vector.register_awkward()
    components = {"pt": pt, "eta": eta, "phi": phi, "mass": mass}
    return vector.zip(components)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _get_variable_names(expr, exclude=["awkward", "ak", "np", "numpy", "math"]):
|
| 155 |
+
import ast
|
| 156 |
+
|
| 157 |
+
root = ast.parse(expr)
|
| 158 |
+
return sorted(
|
| 159 |
+
{
|
| 160 |
+
node.id
|
| 161 |
+
for node in ast.walk(root)
|
| 162 |
+
if isinstance(node, ast.Name) and not node.id.startswith("_")
|
| 163 |
+
}
|
| 164 |
+
- set(exclude)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _eval_expr(expr, table):
    """Evaluate the Python expression `expr` with the table's columns plus a
    fixed toolbox of numeric modules and helper functions in scope.

    SECURITY NOTE: this uses eval(); `expr` must come from trusted
    configuration files only, never from untrusted input.
    """
    namespace = {name: table[name] for name in _get_variable_names(expr)}
    toolbox = {
        "math": math,
        "np": np,
        "numpy": np,
        "ak": ak,
        "awkward": ak,
        "_concat": _concat,
        "_stack": _stack,
        "_pad": _pad,
        "_repeat_pad": _repeat_pad,
        "_clip": _clip,
        "_batch_knn": _batch_knn,
        "_batch_permute_indices": _batch_permute_indices,
        "_batch_argsort": _batch_argsort,
        "_batch_gather": _batch_gather,
        "_p4_from_pxpypze": _p4_from_pxpypze,
        "_p4_from_ptetaphie": _p4_from_ptetaphie,
        "_p4_from_ptetaphim": _p4_from_ptetaphim,
    }
    namespace.update(toolbox)
    return eval(expr, namespace)
|
src/dataset/dataclasses.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, List, Optional
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class PandoraFeatures:
    """Per-hit features from the Pandora (baseline) reconstruction.

    Instances start empty and are filled field-by-field in
    ``Hits.from_data`` when ``args.pandora`` is set; hits without a PFO
    link (link == -1) get their values zeroed there.
    """

    # Features associated to the hits
    pandora_cluster: Optional[Any] = None  # not filled in the visible code path
    pandora_cluster_energy: Optional[Any] = None  # not filled in the visible code path
    pfo_energy: Optional[Any] = None  # energy of the linked PFO (X_pandora column 7)
    pandora_mom: Optional[Any] = None  # momentum magnitude of the linked PFO (column 8)
    pandora_ref_point: Optional[Any] = None  # PFO reference point (columns 4:7)
    pandora_pid: Optional[Any] = None  # PFO particle id (column 0)
    pandora_pfo_link: Optional[Any] = None  # per-hit PFO index; -1 when unmatched
    pandora_mom_components: Optional[Any] = None  # PFO (px, py, pz) (columns 1:4)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class Hits:
    """All detector inputs (calorimeter hits + tracks) for one event.

    Calorimeter-hit rows always come first, followed by the track rows;
    every per-node tensor below is ordered that way.
    """

    pos_xyz_hits: Any  # (N, 3) positions; for tracks, the reference point at the calorimeter
    pos_pxpypz: Any  # (N, 3) track momentum components; zero rows for calorimeter hits
    pos_pxpypz_calo: Any  # (N, 3) track momentum at the calorimeter; zero rows for hits
    p_hits: Any  # (N, 1) momentum magnitude; zero for calorimeter hits
    e_hits: Any  # (N, 1) energy; zero for tracks
    hit_particle_link: Any  # (N,) MC-particle index per hit/track
    pandora_features: Any  # type PandoraFeatures
    hit_type_feature: Any  # (N,) int64 type id (calo-hit type + 1, or track elemtype)
    chi_squared_tracks: Any  # (N,) chi2/ndf for tracks; zero for calorimeter hits
    hit_type_one_hot: Any  # (N, 5) one-hot encoding of hit_type_feature


    @classmethod
    def from_data(cls, output, number_hits, args, number_part):
        """Build a ``Hits`` object from one event's raw arrays.

        Parameters
        ----------
        output : mapping
            Raw event record with keys ``X_hit``, ``X_track``, ``ygen_hit``,
            ``ygen_track`` and -- when ``args.pandora`` is set --
            ``X_pandora``, ``pfo_calohit`` and ``pfo_track``.
        number_hits, number_part : int
            Accepted for interface compatibility with the caller; not used
            inside this method.
        args : namespace
            Parsed CLI options; only ``args.pandora`` is read here.
        """
        # Per-hit MC-particle links, hits first then tracks.
        hit_particle_link_hits = torch.tensor(output["ygen_hit"])
        if len(output["ygen_track"])>0:
            hit_particle_link_tracks= torch.tensor(output["ygen_track"])
            hit_particle_link = torch.cat((hit_particle_link_hits, hit_particle_link_tracks), dim=0)
        else:
            hit_particle_link = hit_particle_link_hits
        # hit_particle_link_calomother = torch.cat((hit_particle_link_hits_calomother, hit_particle_link_tracks), dim=0)
        if args.pandora:
            pandora_features = PandoraFeatures()
            X_pandora = torch.tensor(output["X_pandora"])
            pfo_link_hits = torch.tensor(output["pfo_calohit"])
            if len(output["pfo_track"])>0:
                pfo_link_tracks = torch.tensor(output["pfo_track"])
                pfo_link = torch.cat((pfo_link_hits, pfo_link_tracks), dim=0)
            else:
                pfo_link = pfo_link_hits
            pandora_features.pandora_pfo_link = pfo_link
            # Unlinked entries (-1) are temporarily mapped to row 0 so the
            # gathers below stay in bounds; their values are zeroed afterwards.
            pfo_link_temp = pfo_link.clone()
            pfo_link_temp[pfo_link_temp==-1]=0

            pandora_features.pandora_mom = X_pandora[pfo_link_temp, 8]
            pandora_features.pandora_ref_point = X_pandora[pfo_link_temp, 4:7]
            pandora_features.pandora_mom_components = X_pandora[pfo_link_temp, 1:4]
            pandora_features.pandora_pid = X_pandora[pfo_link_temp, 0]
            pandora_features.pfo_energy = X_pandora[pfo_link_temp, 7]
            # Zero out all Pandora quantities for hits with no PFO link.
            pandora_features.pandora_mom[pfo_link==-1]=0
            pandora_features.pandora_mom_components[pfo_link==-1]=0
            pandora_features.pandora_ref_point[pfo_link==-1]=0
            pandora_features.pandora_pid[pfo_link==-1]=0
            pandora_features.pfo_energy[pfo_link==-1]=0

        else:
            pandora_features = None
        X_hit = torch.tensor(output["X_hit"])
        if len(output["X_track"])>0:
            X_track = torch.tensor(output["X_track"])
        # obtain hit type

        hit_type_feature_hit = X_hit[:,10]+1 #tyep (1,2,3,4 hits)
        if len(output["X_track"])>0:
            hit_type_feature_track = X_track[:,0] #elemtype (1 for tracks)
            hit_type_feature = torch.cat((hit_type_feature_hit, hit_type_feature_track), dim=0).to(torch.int64)
        else:
            hit_type_feature = hit_type_feature_hit.to(torch.int64)
        # obtain the position of the hits and the energies and p
        pos_xyz_hits_hits = X_hit[:,6:9]
        e_hits = X_hit[:,5]
        # Calorimeter hits carry energy but no momentum.
        p_hits = X_hit[:,5]*0

        if len(output["X_track"])>0:
            pos_xyz_hits_tracks = X_track[:,12:15] #(referencePoint_calo.i)
            pos_xyz_hits = torch.cat((pos_xyz_hits_hits, pos_xyz_hits_tracks), dim=0)
            # Tracks carry momentum but no energy deposit.
            e_tracks =X_track[:,5]*0
            e = torch.cat((e_hits, e_tracks), dim=0).view(-1,1)
            p_tracks =X_track[:,5]
            pos_pxpypz_hits_tracks = X_track[:,6:9]
            pos_pxpypz = torch.cat((pos_xyz_hits_hits*0, pos_pxpypz_hits_tracks), dim=0)
            pos_pxpypz_hits_tracks = X_track[:,22:]
            pos_pxpypz_calo = torch.cat((pos_xyz_hits_hits*0, pos_pxpypz_hits_tracks), dim=0)
            p = torch.cat((p_hits, p_tracks), dim=0).view(-1,1)
        else:
            # No tracks in this event: momentum columns are all-zero.
            pos_xyz_hits = pos_xyz_hits_hits
            e = e_hits.view(-1,1)
            pos_pxpypz = pos_xyz_hits_hits*0
            pos_pxpypz_calo = pos_pxpypz
            p = p_hits.view(-1,1)


        if len(output["X_track"])>0:
            # chi2 divided by ndf; p_hits (all zeros) pads the hit rows.
            chi_tracks = X_track[:,15]/ X_track[:,16]
            chi_squared_tracks = torch.cat((p_hits, chi_tracks), dim=0)
        else:
            chi_squared_tracks = p_hits
        # num_classes=5 assumes type ids lie in [0, 4] -- TODO confirm for
        # track elemtypes other than 1.
        hit_type_one_hot = torch.nn.functional.one_hot(
            hit_type_feature, num_classes=5
        )

        return cls(
            pos_xyz_hits=pos_xyz_hits,
            pos_pxpypz=pos_pxpypz,
            pos_pxpypz_calo = pos_pxpypz_calo,
            p_hits=p,
            e_hits=e,
            hit_particle_link=hit_particle_link,
            pandora_features= pandora_features,
            hit_type_feature=hit_type_feature,
            chi_squared_tracks=chi_squared_tracks,
            hit_type_one_hot = hit_type_one_hot,
        )
|
| 125 |
+
|
| 126 |
+
|
src/dataset/dataset.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains a modified version of the dataloader originally from:
|
| 3 |
+
|
| 4 |
+
weaver-core
|
| 5 |
+
https://github.com/hqucms/weaver-core
|
| 6 |
+
|
| 7 |
+
The original implementation has been adapted and extended for the needs of this project.
|
| 8 |
+
Please refer to the original repository for the base implementation and license details.
|
| 9 |
+
Changes in this version:
|
| 10 |
+
- Adapted to read parquet files
|
| 11 |
+
- Modified batching logic to build graphs on the fly
|
| 12 |
+
- No reweighting or standarization of dataset
|
| 13 |
+
"""
|
| 14 |
+
import os
|
| 15 |
+
import copy
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
import awkward as ak
|
| 19 |
+
import torch.utils.data
|
| 20 |
+
import time
|
| 21 |
+
|
| 22 |
+
from functools import partial
|
| 23 |
+
from concurrent.futures.thread import ThreadPoolExecutor
|
| 24 |
+
from src.data.tools import _pad
|
| 25 |
+
from src.data.fileio import _read_files
|
| 26 |
+
from src.data.preprocess import (
|
| 27 |
+
AutoStandardizer,
|
| 28 |
+
WeightMaker,
|
| 29 |
+
)
|
| 30 |
+
from src.dataset.functions_graph import create_graph
|
| 31 |
+
|
| 32 |
+
def _preprocess(table, options):
|
| 33 |
+
indices = np.arange(
|
| 34 |
+
len(table["X_track"])
|
| 35 |
+
)
|
| 36 |
+
if options["shuffle"]:
|
| 37 |
+
np.random.shuffle(indices)
|
| 38 |
+
return table, indices
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _load_next(filelist, load_range, options):
    """Read the next chunk of files and return the ``(table, indices)`` pair.

    Thin wrapper combining the I/O step (``_read_files``) with the
    index-building/shuffling step (``_preprocess``).
    """
    raw_table = _read_files(filelist, load_range)
    return _preprocess(raw_table, options)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _SimpleIter(object):
|
| 50 |
+
r"""_SimpleIter
|
| 51 |
+
Iterator object for ``SimpleIterDataset''.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, **kwargs):
|
| 55 |
+
# inherit all properties from SimpleIterDataset
|
| 56 |
+
self.__dict__.update(**kwargs)
|
| 57 |
+
self.iter_count = 0
|
| 58 |
+
|
| 59 |
+
# executor to read files and run preprocessing asynchronously
|
| 60 |
+
self.executor = ThreadPoolExecutor(max_workers=1) if self._async_load else None
|
| 61 |
+
|
| 62 |
+
# init: prefetch holds table and indices for the next fetch
|
| 63 |
+
self.prefetch = None
|
| 64 |
+
self.table = None
|
| 65 |
+
self.indices = []
|
| 66 |
+
self.cursor = 0
|
| 67 |
+
|
| 68 |
+
self._seed = None
|
| 69 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 70 |
+
file_dict = self._init_file_dict.copy()
|
| 71 |
+
if worker_info is not None:
|
| 72 |
+
# in a worker process
|
| 73 |
+
self._name += "_worker%d" % worker_info.id
|
| 74 |
+
self._seed = worker_info.seed & 0xFFFFFFFF
|
| 75 |
+
np.random.seed(self._seed)
|
| 76 |
+
# split workload by files
|
| 77 |
+
new_file_dict = {}
|
| 78 |
+
for name, files in file_dict.items():
|
| 79 |
+
new_files = files[worker_info.id :: worker_info.num_workers]
|
| 80 |
+
assert len(new_files) > 0
|
| 81 |
+
new_file_dict[name] = new_files
|
| 82 |
+
file_dict = new_file_dict
|
| 83 |
+
self.worker_file_dict = file_dict
|
| 84 |
+
self.worker_filelist = sum(file_dict.values(), [])
|
| 85 |
+
self.worker_info = worker_info
|
| 86 |
+
self.restart()
|
| 87 |
+
|
| 88 |
+
def restart(self):
|
| 89 |
+
print("=== Restarting DataIter %s, seed=%s ===" % (self._name, self._seed))
|
| 90 |
+
# re-shuffle filelist and load range if for training
|
| 91 |
+
filelist = self.worker_filelist.copy()
|
| 92 |
+
if self._sampler_options["shuffle"]:
|
| 93 |
+
np.random.shuffle(filelist)
|
| 94 |
+
if self._file_fraction < 1:
|
| 95 |
+
num_files = int(len(filelist) * self._file_fraction)
|
| 96 |
+
filelist = filelist[:num_files]
|
| 97 |
+
self.filelist = filelist
|
| 98 |
+
|
| 99 |
+
if self._init_load_range_and_fraction is None:
|
| 100 |
+
self.load_range = (0, 1)
|
| 101 |
+
else:
|
| 102 |
+
(start_pos, end_pos), load_frac = self._init_load_range_and_fraction
|
| 103 |
+
interval = (end_pos - start_pos) * load_frac
|
| 104 |
+
if self._sampler_options["shuffle"]:
|
| 105 |
+
offset = np.random.uniform(start_pos, end_pos - interval)
|
| 106 |
+
self.load_range = (offset, offset + interval)
|
| 107 |
+
else:
|
| 108 |
+
self.load_range = (start_pos, start_pos + interval)
|
| 109 |
+
|
| 110 |
+
self.ipos = 0 if self._fetch_by_files else self.load_range[0]
|
| 111 |
+
# prefetch the first entry asynchronously
|
| 112 |
+
self._try_get_next(init=True)
|
| 113 |
+
|
| 114 |
+
def __next__(self):
|
| 115 |
+
graph_empty = True
|
| 116 |
+
self.iter_count += 1
|
| 117 |
+
|
| 118 |
+
while graph_empty:
|
| 119 |
+
if len(self.filelist) == 0:
|
| 120 |
+
raise StopIteration
|
| 121 |
+
try:
|
| 122 |
+
i = self.indices[self.cursor]
|
| 123 |
+
except IndexError:
|
| 124 |
+
# case 1: first entry, `self.indices` is still empty
|
| 125 |
+
# case 2: running out of entries, `self.indices` is not empty
|
| 126 |
+
while True:
|
| 127 |
+
if self.prefetch is None:
|
| 128 |
+
# reaching the end as prefetch got nothing
|
| 129 |
+
self.table = None
|
| 130 |
+
if self._async_load:
|
| 131 |
+
self.executor.shutdown(wait=False)
|
| 132 |
+
raise StopIteration
|
| 133 |
+
# get result from prefetch
|
| 134 |
+
if self._async_load:
|
| 135 |
+
self.table, self.indices = self.prefetch.result()
|
| 136 |
+
else:
|
| 137 |
+
self.table, self.indices = self.prefetch
|
| 138 |
+
# try to load the next ones asynchronously
|
| 139 |
+
self._try_get_next()
|
| 140 |
+
# check if any entries are fetched (i.e., passing selection) -- if not, do another fetch
|
| 141 |
+
if len(self.indices) > 0:
|
| 142 |
+
break
|
| 143 |
+
# reset cursor
|
| 144 |
+
self.cursor = 0
|
| 145 |
+
i = self.indices[self.cursor]
|
| 146 |
+
self.cursor += 1
|
| 147 |
+
data, graph_empty = self.get_data(i)
|
| 148 |
+
return data
|
| 149 |
+
|
| 150 |
+
def _try_get_next(self, init=False):
|
| 151 |
+
end_of_list = (
|
| 152 |
+
self.ipos >= len(self.filelist)
|
| 153 |
+
if self._fetch_by_files
|
| 154 |
+
else self.ipos >= self.load_range[1]
|
| 155 |
+
)
|
| 156 |
+
if end_of_list:
|
| 157 |
+
if init:
|
| 158 |
+
raise RuntimeError(
|
| 159 |
+
"Nothing to load for worker %d" % 0
|
| 160 |
+
if self.worker_info is None
|
| 161 |
+
else self.worker_info.id
|
| 162 |
+
)
|
| 163 |
+
if self._infinity_mode and not self._in_memory:
|
| 164 |
+
# infinity mode: re-start
|
| 165 |
+
self.restart()
|
| 166 |
+
return
|
| 167 |
+
else:
|
| 168 |
+
# finite mode: set prefetch to None, exit
|
| 169 |
+
self.prefetch = None
|
| 170 |
+
return
|
| 171 |
+
if self._fetch_by_files:
|
| 172 |
+
filelist = self.filelist[int(self.ipos) : int(self.ipos + self._fetch_step)]
|
| 173 |
+
load_range = self.load_range
|
| 174 |
+
else:
|
| 175 |
+
filelist = self.filelist
|
| 176 |
+
load_range = (
|
| 177 |
+
self.ipos,
|
| 178 |
+
min(self.ipos + self._fetch_step, self.load_range[1]),
|
| 179 |
+
)
|
| 180 |
+
print('Start fetching next batch, len(filelist)=%d, load_range=%s'%(len(filelist), load_range))
|
| 181 |
+
if self._async_load:
|
| 182 |
+
self.prefetch = self.executor.submit(
|
| 183 |
+
_load_next,
|
| 184 |
+
filelist,
|
| 185 |
+
load_range,
|
| 186 |
+
self._sampler_options,
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
self.prefetch = _load_next(
|
| 190 |
+
filelist, load_range, self._sampler_options
|
| 191 |
+
)
|
| 192 |
+
self.ipos += self._fetch_step
|
| 193 |
+
|
| 194 |
+
def get_data(self, i):
|
| 195 |
+
# inputs
|
| 196 |
+
self.args_parse.prediction = (not self.for_training)
|
| 197 |
+
# X = {k: self.table["_" + k][i].copy() for k in self._data_config.input_names}
|
| 198 |
+
X = {k: self.table[k][i] for k in self.table.fields}
|
| 199 |
+
[g, features_partnn], graph_empty = create_graph(
|
| 200 |
+
X, self.for_training, self.args_parse
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
return [g, features_partnn], graph_empty
|
| 204 |
+
# return X, False
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class SimpleIterDataset(torch.utils.data.IterableDataset):
    r"""Base IterableDataset.
    Handles dataloading.
    Arguments:
        file_dict (dict): dictionary of lists of files to be loaded.
        data_config_file (str): YAML file containing data format information.
        for_training (bool): flag indicating whether the dataset is used for training or testing.
            When set to ``True``, will enable shuffling and sampling-based reweighting.
            When set to ``False``, will disable shuffling and reweighting, but will load the observer variables.
        load_range_and_fraction (tuple of tuples, ``((start_pos, end_pos), load_frac)``): fractional range of events to load from each file.
            E.g., setting load_range_and_fraction=((0, 0.8), 0.5) will randomly load 50% out of the first 80% events from each file (so load 50%*80% = 40% of the file).
        fetch_by_files (bool): flag to control how events are retrieved each time we fetch data from disk.
            When set to ``True``, will read only a small number (set by ``fetch_step``) of files each time, but load all the events in these files.
            When set to ``False``, will read from all input files, but load only a small fraction (set by ``fetch_step``) of events each time.
            Default is ``False``, which results in a more uniform sample distribution but reduces the data loading speed.
        fetch_step (float or int): fraction of events (when ``fetch_by_files=False``) or number of files (when ``fetch_by_files=True``) to load each time we fetch data from disk.
            Event shuffling and reweighting (sampling) is performed each time after we fetch data.
            So set this to a large enough value to avoid getting an imbalanced minibatch (due to reweighting/sampling), especially when ``fetch_by_files`` set to ``True``.
            Will load all events (files) at once if set to non-positive value.
        file_fraction (float): fraction of files to load.
    """

    def __init__(
        self,
        file_dict,
        data_config_file,
        for_training=True,
        load_range_and_fraction=None,
        extra_selection=None,
        fetch_by_files=False,
        fetch_step=0.01,
        file_fraction=1,
        remake_weights=False,
        up_sample=True,
        weight_scale=1,
        max_resample=10,
        async_load=True,
        infinity_mode=False,
        name="",
        args_parse=None
    ):
        # In infinity mode one iterator is kept alive per worker and reused
        # across epochs; otherwise a fresh iterator is built on every
        # __iter__ call.
        # NOTE(review): data_config_file, extra_selection and remake_weights
        # are accepted but not stored here -- confirm they are intentionally
        # unused in this trimmed version.
        self._iters = {} if infinity_mode else None
        # Snapshot the attribute names present *before* the config block so
        # that self._init_args (computed below by set difference) captures
        # exactly the attributes added here; _SimpleIter receives deep copies
        # of these.
        _init_args = set(self.__dict__.keys())
        self._init_file_dict = file_dict
        self._init_load_range_and_fraction = load_range_and_fraction
        self._fetch_by_files = fetch_by_files
        self._fetch_step = fetch_step
        self._file_fraction = file_fraction
        self._async_load = async_load
        self._infinity_mode = infinity_mode
        self._name = name
        self.for_training = for_training
        self.args_parse = args_parse
        # ==== sampling parameters ====
        self._sampler_options = {
            "up_sample": up_sample,
            "weight_scale": weight_scale,
            "max_resample": max_resample,
        }

        if for_training:
            self._sampler_options.update(training=True, shuffle=True, reweight=True)
        else:
            self._sampler_options.update(training=False, shuffle=False, reweight=False)
        # Names of every attribute assigned in this constructor.
        self._init_args = set(self.__dict__.keys()) - _init_args



    def __iter__(self):
        # Finite mode: hand out a fresh iterator each epoch.
        if self._iters is None:
            kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args}
            return _SimpleIter(**kwargs)
        else:
            # Infinity mode: cache one iterator per dataloader worker.
            worker_info = torch.utils.data.get_worker_info()
            worker_id = worker_info.id if worker_info is not None else 0
            try:
                return self._iters[worker_id]
            except KeyError:
                kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args}
                self._iters[worker_id] = _SimpleIter(**kwargs)
                return self._iters[worker_id]
|
src/dataset/functions_data.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calculate_distance_to_boundary(g):
    """Attach each hit's distance to the detector boundary to the graph.

    Barrel hits are projected radially onto a cylinder of radius 2150;
    endcap hits (|z| > 2307) are scaled onto the endcap plane. Writes
    ``radial_distance`` and its exponential falloff ``radial_distance_exp``
    into ``g.ndata`` and returns ``g``.
    """
    barrel_radius = 2150
    endcap_z = 2307
    xyz = g.ndata["pos_hits_xyz"]
    in_endcap = (torch.abs(xyz[:, 2]) - endcap_z) > 0
    in_barrel = ~in_endcap
    # Unit z-axis; |z_axis x pos| is the transverse (radial) distance.
    z_axis = torch.tensor([0, 0, 1], dtype=xyz.dtype, device=xyz.device)
    radial_norm = torch.norm(torch.cross(z_axis.view(1, -1), xyz, dim=-1), dim=1)
    # Scale each position onto the barrel cylinder / the endcap plane.
    barrel_proj = barrel_radius * 1 / radial_norm.unsqueeze(1) * xyz
    endcap_proj = torch.abs(endcap_z / xyz[:, 2].unsqueeze(1)) * xyz
    distance = torch.ones_like(xyz[:, 0])
    distance[in_barrel] = torch.norm(barrel_proj - xyz, dim=1)[in_barrel]
    distance[in_endcap] = torch.norm(endcap_proj[in_endcap] - xyz[in_endcap], dim=1)
    g.ndata["radial_distance"] = distance
    g.ndata["radial_distance_exp"] = torch.exp(-(distance / 1000))
    return g
|
| 26 |
+
|
src/dataset/functions_graph.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import dgl
|
| 4 |
+
from src.dataset.functions_data import (
|
| 5 |
+
calculate_distance_to_boundary,
|
| 6 |
+
)
|
| 7 |
+
import time
|
| 8 |
+
from src.dataset.functions_particles import concatenate_Particles_GT, Particles_GT
|
| 9 |
+
|
| 10 |
+
from src.dataset.dataclasses import Hits
|
| 11 |
+
|
| 12 |
+
def create_inputs_from_table(output, prediction=False, args=None):
    """Build the ground-truth particles and the hit collection for one event.

    Args:
        output: raw event record (mapping of arrays).
        prediction: True when running inference rather than training.
        args: parsed CLI options, forwarded to the builders.

    Returns:
        list: ``[Particles_GT, Hits]`` for this event.
    """
    n_hits = np.int32(len(output["X_track"]) + len(output["X_hit"]))
    n_particles = np.int32(len(output["X_gen"]))

    hit_collection = Hits.from_data(output, n_hits, args, n_particles)

    particles = Particles_GT()
    particles.fill(output, prediction, args)

    return [particles, hit_collection]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_graph(output, for_training=True, args=None):
    """Turn one raw event table into a DGL graph plus its ground truth.

    Args:
        output: raw event record (mapping of arrays).
        for_training: training mode; its inverse is the prediction flag.
        args: parsed CLI options (``args.pandora`` controls the extra
            Pandora node features in prediction mode).

    Returns:
        tuple: ``([graph, Particles_GT], graph_empty)`` where ``graph_empty``
        flags events to skip (fewer than 10 nodes, or -- during training --
        only noise hits).
    """
    prediction = not for_training

    result = create_inputs_from_table(output, prediction=prediction, args=args)
    if len(result) == 1:
        # Degenerate event: nothing to build a graph from.
        return [0, 0], True

    y_data_graph, hits = result

    g = dgl.graph(([], []))
    g.add_nodes(hits.pos_xyz_hits.shape[0])
    g.ndata["h"] = torch.cat(
        (hits.pos_xyz_hits, hits.hit_type_one_hot, hits.e_hits, hits.p_hits), dim=1
    ).float()
    g.ndata["p_hits"] = hits.p_hits.float()
    g.ndata["pos_hits_xyz"] = hits.pos_xyz_hits.float()
    g.ndata["pos_pxpypz_at_vertex"] = hits.pos_pxpypz.float()
    g.ndata["pos_pxpypz"] = hits.pos_pxpypz  # TrackState::AtIP
    g.ndata["pos_pxpypz_at_calo"] = hits.pos_pxpypz_calo  # TrackState::AtCalorimeter
    g = calculate_distance_to_boundary(g)
    g.ndata["hit_type"] = hits.hit_type_feature.float()
    g.ndata["e_hits"] = hits.e_hits.float()
    g.ndata["chi_squared_tracks"] = hits.chi_squared_tracks.float()
    # Shift by one so noise gets index 0 and MC particle 0 starts at 1.
    g.ndata["particle_number"] = hits.hit_particle_link.float() + 1

    if prediction and args.pandora:
        pandora = hits.pandora_features
        g.ndata["pandora_pfo"] = pandora.pandora_pfo_link.float()
        g.ndata["pandora_pfo_energy"] = pandora.pfo_energy.float()
        g.ndata["pandora_momentum"] = pandora.pandora_mom_components.float()
        g.ndata["pandora_reference_point"] = pandora.pandora_ref_point.float()
        g.ndata["pandora_pid"] = pandora.pandora_pid.float()

    unique_links = torch.unique(hits.hit_particle_link)
    only_noise = unique_links.shape[0] == 1 and unique_links[0] == -1
    graph_empty = (not prediction and only_noise) or hits.pos_xyz_hits.shape[0] < 10
    return [g, y_data_graph], graph_empty
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def graph_batch_func(list_graphs):
    """Collator function for the graph dataloader.

    Args:
        list_graphs (list): list of ``[graph, Particles_GT]`` pairs from the
            iterable dataset.

    Returns:
        tuple: ``(batched_dgl_graph, batched_Particles_GT)``.
    """
    ys = concatenate_Particles_GT(list_graphs)
    bg = dgl.batch([pair[0] for pair in list_graphs])
    return bg, ys
|
src/dataset/functions_particles.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, List, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class Particles_GT:
    """Ground-truth particle attributes for one event (or a batch of events).

    All fields default to ``None`` and are populated either by :meth:`fill`
    (single event, read from the ``X_gen`` table) or by
    ``concatenate_Particles_GT`` (batched).
    """

    # Per-particle tensors; one row per particle.
    angle: Optional[Any] = None  # angular coordinates, X_gen columns 4:6
    coord: Optional[Any] = None  # position, X_gen columns 12:15
    E: Optional[Any] = None  # energy, X_gen column 8
    E_corrected: Optional[Any] = None  # initialized identical to E in fill()
    m: Optional[Any] = None  # momentum, X_gen column 11
    mass: Optional[Any] = None  # mass, X_gen column 10
    pid: Optional[Any] = None  # particle id, X_gen column 0
    vertex: Optional[Any] = None  # production vertex, X_gen columns 15:18
    gen_status: Optional[Any] = None  # generator status, X_gen column 1
    batch_number: Optional[Any] = None  # event index per particle; batched instances only
    endpoint: Optional[Any] = None  # not filled in this module

    def fill(self, output, prediction, args):
        """Populate the per-particle fields from the ``X_gen`` table.

        ``prediction`` and ``args`` are accepted for interface compatibility
        with the callers but are not used here.
        """
        features_particles = torch.tensor(output["X_gen"])
        particle_coord_angle = features_particles[:, 4:6]
        particle_coord = features_particles[:, 12:15]
        vertex_coord = features_particles[:, 15:18]

        y_mass = features_particles[:, 10].view(-1).unsqueeze(1)
        y_mom = features_particles[:, 11].view(-1).unsqueeze(1)
        y_energy = features_particles[:, 8].view(-1).unsqueeze(1)
        y_pid = features_particles[:, 0]
        gen_status = features_particles[:, 1]

        self.angle = particle_coord_angle
        self.coord = particle_coord
        # E_corrected starts out identical to E; downstream code may refine it.
        self.E_corrected = y_energy
        self.E = y_energy
        self.m = y_mom
        self.mass = y_mass
        self.pid = y_pid
        self.vertex = vertex_coord
        self.gen_status = gen_status

    def __len__(self):
        """Number of particles (rows of the energy tensor)."""
        return len(self.E)

    def mask(self, mask):
        """Apply a boolean/index *mask* to every populated field in place."""
        for k in self.__dict__:
            value = getattr(self, k)
            if value is None:
                continue
            # isinstance instead of `type(...) == list` (idiomatic type check).
            if isinstance(value, list):
                # A list field is masked only when its entries are populated.
                if value[0] is not None:
                    setattr(self, k, value[mask])
            else:
                setattr(self, k, value[mask])

    def copy(self):
        """Return a shallow copy sharing the underlying tensors."""
        obj = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def concatenate_Particles_GT(list_of_Particles_GT):
    """Concatenate per-event ``Particles_GT`` objects into one batched object.

    Parameters
    ----------
    list_of_Particles_GT : list
        List of ``[graph, Particles_GT]`` pairs, one per event.

    Returns
    -------
    Particles_GT
        A single object whose tensor fields are the row-wise concatenation
        of all events, plus a ``batch_number`` column mapping each particle
        back to its event.
    """
    particles = [pair[1] for pair in list_of_Particles_GT]

    list_angle = torch.cat([p.angle for p in particles], dim=0)
    list_coord = torch.cat([p.coord for p in particles], dim=0)
    list_E = torch.cat([p.E for p in particles], dim=0)
    list_E_corr = torch.cat([p.E_corrected for p in particles], dim=0)
    list_m = torch.cat([p.m for p in particles], dim=0)
    list_mass = torch.cat([p.mass for p in particles], dim=0)
    list_pid = torch.cat([p.pid for p in particles], dim=0)
    list_genstatus = torch.cat([p.gen_status for p in particles], dim=0)

    # BUGFIX: the presence checks below used `hasattr` on the
    # [graph, particles] *pair* (a list), which is always False.  `endpoint`
    # is a declared dataclass field, so the meaningful test is whether it is
    # actually populated.
    if particles[0].endpoint is not None:
        list_endpoint = torch.cat([p.endpoint for p in particles], dim=0)
    else:
        list_endpoint = None

    list_vertex = [p.vertex for p in particles]
    if list_vertex[0] is not None:
        list_vertex = torch.cat(list_vertex, dim=0)

    # `decayed_in_calo`/`decayed_in_tracker` are attached dynamically (they
    # are not dataclass fields), so `hasattr` on the particle object is the
    # right presence test here.
    if hasattr(particles[0], "decayed_in_calo"):
        list_dec_calo = torch.cat([p.decayed_in_calo for p in particles], dim=0)
        list_dec_track = torch.cat([p.decayed_in_tracker for p in particles], dim=0)
    else:
        list_dec_calo = None
        list_dec_track = None

    batch_number = add_batch_number(list_of_Particles_GT)

    particle_batch = Particles_GT()
    particle_batch.angle = list_angle
    particle_batch.coord = list_coord
    particle_batch.E = list_E
    particle_batch.E_corrected = list_E_corr
    particle_batch.m = list_m
    particle_batch.mass = list_mass  # BUGFIX: was computed but never stored
    particle_batch.pid = list_pid
    particle_batch.vertex = list_vertex
    particle_batch.decayed_in_calo = list_dec_calo
    particle_batch.decayed_in_tracker = list_dec_track
    particle_batch.batch_number = batch_number
    particle_batch.gen_status = list_genstatus
    particle_batch.endpoint = list_endpoint
    return particle_batch
|
| 114 |
+
|
| 115 |
+
def add_batch_number(list_graphs):
    """Build a column assigning each particle its event index in the batch.

    Args:
        list_graphs (list): list of ``[graph, Particles_GT]`` pairs.

    Returns:
        Tensor of shape ``(total_particles, 1)`` where row r holds the
        (float) index of the event that particle r came from.
    """
    per_event = [
        torch.full((pair[1].E.shape[0], 1), float(event_idx))
        for event_idx, pair in enumerate(list_graphs)
    ]
    return torch.cat(per_event, dim=0)
|
src/inference.py
ADDED
|
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Standalone single-event MLPF inference.
|
| 3 |
+
|
| 4 |
+
Provides :func:`run_single_event_inference` which takes raw event data
|
| 5 |
+
(from a parquet file or as an awkward record) and model checkpoint paths,
|
| 6 |
+
runs the full particle-flow pipeline (graph construction → GATr forward
|
| 7 |
+
pass → density-peak clustering → energy correction & PID), and returns:
|
| 8 |
+
|
| 9 |
+
* a ``pandas.DataFrame`` of predicted particles with their properties
|
| 10 |
+
* a hit→cluster mapping as a ``pandas.DataFrame``
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import types
|
| 15 |
+
from typing import Optional
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import torch
|
| 19 |
+
import dgl
|
| 20 |
+
import awkward as ak
|
| 21 |
+
|
| 22 |
+
from src.data.fileio import _read_parquet
|
| 23 |
+
from src.dataset.functions_graph import create_graph
|
| 24 |
+
from src.dataset.functions_particles import Particles_GT, add_batch_number
|
| 25 |
+
from src.layers.clustering import DPC_custom_CLD, remove_bad_tracks_from_cluster
|
| 26 |
+
from src.utils.pid_conversion import pid_conversion_dict
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# -- CPU-compatible attention patch ------------------------------------------
|
| 30 |
+
|
| 31 |
+
def _patch_gatr_attention_for_cpu():
    """Replace GATr's xformers-based attention with a naive implementation.

    ``xformers.ops.fmha.memory_efficient_attention`` has no CPU kernel, so
    running GATr on CPU crashes. This function monkey-patches
    ``gatr.primitives.attention.scaled_dot_product_attention`` with a plain
    PyTorch implementation that works on any device (albeit slower on GPU).
    The patch is applied at most once (guarded by the ``_cpu_patched`` flag
    set on the gatr attention module).
    """
    import gatr.primitives.attention as _gatr_attn

    # Idempotence guard: a second call is a no-op.
    if getattr(_gatr_attn, "_cpu_patched", False):
        return

    def _cpu_sdpa(q, k, v, attn_mask=None):
        # q, k, v: (B, H, N, D) — batch, heads, items, dim
        B, H, N, D = q.shape
        # Standard 1/sqrt(D) attention scaling.
        scale = float(D) ** -0.5

        # Fold batch and head dims together so we can use bmm.
        q2 = q.reshape(B * H, N, D)
        k2 = k.reshape(B * H, N, D)
        v2 = v.reshape(B * H, N, D)

        attn = torch.bmm(q2 * scale, k2.transpose(1, 2))  # (B*H, N, N)

        if attn_mask is not None:
            # xformers passes a BlockDiagonalMask; densify it (None means
            # "unsupported mask type" and we fall back to unmasked attention).
            dense = _block_diag_mask_to_dense(attn_mask, N, q.device)
            if dense is not None:
                attn = attn.masked_fill(~dense.unsqueeze(0), float("-inf"))

        attn = torch.softmax(attn, dim=-1)
        # Rows that are fully masked produce NaN after softmax; zero them out.
        attn = attn.nan_to_num(0.0)

        out = torch.bmm(attn, v2)  # (B*H, N, D)
        return out.reshape(B, H, N, D)

    _gatr_attn.scaled_dot_product_attention = _cpu_sdpa
    _gatr_attn._cpu_patched = True
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _block_diag_mask_to_dense(attn_mask, total_len, device):
|
| 73 |
+
"""Convert an ``xformers.ops.fmha.BlockDiagonalMask`` to a dense bool mask."""
|
| 74 |
+
try:
|
| 75 |
+
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
| 76 |
+
if not isinstance(attn_mask, BlockDiagonalMask):
|
| 77 |
+
return None
|
| 78 |
+
except ImportError:
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
# Extract per-sequence start offsets
|
| 82 |
+
try:
|
| 83 |
+
seqstarts = attn_mask.q_seqinfo.seqstart_py
|
| 84 |
+
except AttributeError:
|
| 85 |
+
try:
|
| 86 |
+
seqstarts = attn_mask.q_seqinfo.seqstart.cpu().tolist()
|
| 87 |
+
except Exception:
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
|
| 91 |
+
for i in range(len(seqstarts) - 1):
|
| 92 |
+
s, e = seqstarts[i], seqstarts[i + 1]
|
| 93 |
+
mask[s:e, s:e] = True
|
| 94 |
+
return mask
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# -- PID label → human-readable name ----------------------------------------
|
| 98 |
+
|
| 99 |
+
# Model PID class index -> human-readable label.
_PID_LABELS = {
    0: "electron",
    1: "charged hadron",
    2: "neutral hadron",
    3: "photon",
    4: "muon",
}

# |PDG code| -> particle name for the species named here; lookups elsewhere
# fall back to the raw PDG number as a string for codes not in this table.
_ABS_PDG_NAME = {
    11: "electron",
    13: "muon",
    22: "photon",
    130: "K_L",
    211: "pion±",
    321: "kaon±",
    2112: "neutron",
    2212: "proton",
    310: "K_S",
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# -- Minimal args namespace for inference ------------------------------------
|
| 121 |
+
|
| 122 |
+
def _default_args(**overrides):
|
| 123 |
+
"""Return a minimal ``argparse.Namespace`` with defaults the model expects."""
|
| 124 |
+
d = dict(
|
| 125 |
+
correction=True,
|
| 126 |
+
freeze_clustering=True,
|
| 127 |
+
predict=True,
|
| 128 |
+
pandora=False,
|
| 129 |
+
use_gt_clusters=False,
|
| 130 |
+
use_average_cc_pos=0.99,
|
| 131 |
+
qmin=1.0,
|
| 132 |
+
data_config="config_files/config_hits_track_v4.yaml",
|
| 133 |
+
network_config="src/models/wrapper/example_mode_gatr_noise.py",
|
| 134 |
+
model_prefix="/tmp/mlpf_eval",
|
| 135 |
+
start_lr=1e-3,
|
| 136 |
+
frac_cluster_loss=0,
|
| 137 |
+
local_rank=0,
|
| 138 |
+
gpus="0",
|
| 139 |
+
batch_size=1,
|
| 140 |
+
num_workers=0,
|
| 141 |
+
prefetch_factor=1,
|
| 142 |
+
num_epochs=1,
|
| 143 |
+
steps_per_epoch=None,
|
| 144 |
+
samples_per_epoch=None,
|
| 145 |
+
steps_per_epoch_val=None,
|
| 146 |
+
samples_per_epoch_val=None,
|
| 147 |
+
train_val_split=0.8,
|
| 148 |
+
data_train=[],
|
| 149 |
+
data_val=[],
|
| 150 |
+
data_test=[],
|
| 151 |
+
data_fraction=1,
|
| 152 |
+
file_fraction=1,
|
| 153 |
+
fetch_by_files=True,
|
| 154 |
+
fetch_step=1,
|
| 155 |
+
log_wandb=False,
|
| 156 |
+
wandb_displayname="",
|
| 157 |
+
wandb_projectname="",
|
| 158 |
+
wandb_entity="",
|
| 159 |
+
name_output="gradio",
|
| 160 |
+
train_batches=100,
|
| 161 |
+
)
|
| 162 |
+
d.update(overrides)
|
| 163 |
+
return argparse.Namespace(**d)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# -- Model loading -----------------------------------------------------------
|
| 167 |
+
|
| 168 |
+
def load_model(
    clustering_ckpt: str,
    energy_pid_ckpt: Optional[str] = None,
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Load the full MLPF model (clustering + optional energy/PID correction).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (``.ckpt``).
    energy_pid_ckpt : str or None
        Path to the energy-correction / PID checkpoint (``.ckpt``).
        If *None*, only clustering is performed (no energy correction / PID).
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    overrides = dict(args_overrides or {})
    has_correction = energy_pid_ckpt is not None
    # args.correction gates the correction branch downstream, so it must
    # mirror whether an energy/PID checkpoint was provided.
    overrides["correction"] = has_correction

    args = _default_args(**overrides)
    dev = torch.device(device)

    if has_correction:
        # Load the energy/PID weights first, non-strictly (the checkpoint may
        # not cover the clustering sub-modules).
        ckpt = torch.load(energy_pid_ckpt, map_location=dev)
        state_dict = ckpt["state_dict"]
        # NOTE(review): ExampleWrapper receives dev=0 regardless of *device*;
        # presumably an internal device index — confirm against the wrapper.
        model = ExampleWrapper(args=args, dev=0)
        model.load_state_dict(state_dict, strict=False)
        # Overwrite clustering layers from clustering checkpoint
        model2 = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )
        model.gatr = model2.gatr
        model.ScaledGooeyBatchNorm2_1 = model2.ScaledGooeyBatchNorm2_1
        model.clustering = model2.clustering
        model.beta = model2.beta
    else:
        model = ExampleWrapper.load_from_checkpoint(
            clustering_ckpt, args=args, dev=0, strict=False, map_location=dev,
        )

    model = model.to(dev)
    # eval() so BatchNorm uses running statistics during single-event inference.
    model.eval()
    return model, args
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def load_random_model(
    device: str = "cpu",
    args_overrides: Optional[dict] = None,
):
    """Create a GATr model with randomly initialised weights (no checkpoint).

    Useful for debugging: compare against a checkpointed model to verify
    that checkpoint weights are actually being loaded and used.

    Parameters
    ----------
    device : str
        ``"cpu"`` or ``"cuda:0"`` etc.
    args_overrides : dict or None
        Extra key-value pairs forwarded to :func:`_default_args`.

    Returns
    -------
    model : ExampleWrapper
        The model (random weights) in eval mode, on *device*.
    args : argparse.Namespace
        The arguments namespace used.
    """
    from src.models.Gatr_pf_e_noise import ExampleWrapper

    # No energy/PID checkpoint exists here, so the correction branch is off.
    merged = {**(args_overrides or {}), "correction": False}
    args = _default_args(**merged)

    target = torch.device(device)
    net = ExampleWrapper(args=args, dev=0).to(target)
    net.eval()
    return net, args
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# -- Single-event data loading -----------------------------------------------
|
| 265 |
+
|
| 266 |
+
def load_event_from_parquet(parquet_path: str, event_index: int = 0):
    """Read a single event from a parquet file.

    Returns a dict mapping each field name (``X_hit``, ``X_track``,
    ``X_gen``, ``ygen_hit``, ``ygen_track``, ...) to that field's value for
    the selected event.

    Raises
    ------
    IndexError
        If *event_index* is not smaller than the number of events on file.
    """
    table = _read_parquet(parquet_path)
    n_events = len(table["X_track"])
    if event_index >= n_events:
        raise IndexError(
            f"event_index {event_index} out of range (file has {n_events} events)"
        )
    return {name: table[name][event_index] for name in table.fields}
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# -- Core inference function --------------------------------------------------
|
| 283 |
+
|
| 284 |
+
@torch.no_grad()
def run_single_event_inference(
    event,
    model,
    args,
    device: str = "cpu",
):
    """Run full MLPF inference on a single event.

    Pipeline: graph construction -> GATr forward pass -> density-peak
    clustering -> (optional) energy correction & PID.

    Parameters
    ----------
    event : dict-like
        A single event record (from :func:`load_event_from_parquet`).
    model : ExampleWrapper
        The loaded model (from :func:`load_model`).
    args : argparse.Namespace
        The arguments namespace (from :func:`load_model`).
    device : str
        Device string.

    Returns
    -------
    particles_df : pandas.DataFrame
        One row per predicted particle. Before energy correction the
        columns are ``cluster_id``, ``energy_sum_hits``, ``p_track``,
        ``n_hits``, ``is_charged``, ``mean_x``, ``mean_y``, ``mean_z``;
        when the correction branch runs, :func:`_run_energy_correction`
        enriches this table further.
    hit_cluster_df : pandas.DataFrame
        One row per hit with columns:
        ``hit_index``, ``cluster_id``, ``pandora_cluster_id``,
        ``hit_type_id``, ``hit_type``, ``x``, ``y``, ``z``,
        ``hit_energy``, ``cluster_x``, ``cluster_y``, ``cluster_z``.
        ``pandora_cluster_id`` is -1 when pandora data is not available
        or when the hit has no matching entry (e.g. CSV was modified after
        loading from parquet).
    mc_particles_df : pandas.DataFrame
        One row per MC truth particle (see :func:`_extract_mc_particles`).
    pandora_particles_df : pandas.DataFrame
        One row per Pandora PFO (see :func:`_extract_pandora_particles`).
        Empty when pandora data is not available in the input.
    """
    dev = torch.device(device)

    # Ensure eval mode so that BatchNorm layers use running statistics from
    # training instead of computing batch statistics from the current
    # (single-event) input. Without this, inference with batch_size=1
    # produces incorrect normalization.
    model.eval()

    # xformers attention has no CPU kernel; swap in a naive implementation.
    if dev.type == "cpu":
        _patch_gatr_attention_for_cpu()

    # 0. Extract MC truth particles table and pandora particles
    mc_particles_df = _extract_mc_particles(event)
    pandora_particles_df, pfo_calohit, pfo_track = _extract_pandora_particles(event)

    # 1. Build DGL graph from the event
    [g, y_data], graph_empty = create_graph(event, for_training=False, args=args)
    if graph_empty:
        return pd.DataFrame(), pd.DataFrame(), mc_particles_df, pandora_particles_df

    g = g.to(dev)
    # Prepare batch metadata expected by the model
    y_data.batch_number = torch.zeros(y_data.E.shape[0], 1)

    # 2. Forward pass through the GATr clustering backbone
    inputs = g.ndata["pos_hits_xyz"].float().to(dev)
    inputs_scalar = g.ndata["hit_type"].float().view(-1, 1).to(dev)

    from gatr.interface import embed_point, embed_scalar
    from xformers.ops.fmha import BlockDiagonalMask

    # Embed hit positions as geometric-algebra points and hit types as scalars.
    inputs_normed = model.ScaledGooeyBatchNorm2_1(inputs)
    embedded_inputs = embed_point(inputs_normed) + embed_scalar(inputs_scalar)
    embedded_inputs = embedded_inputs.unsqueeze(-2)
    # Single event -> one attention block covering all nodes.
    mask = BlockDiagonalMask.from_seqlens([g.num_nodes()])
    scalars = torch.cat(
        (g.ndata["e_hits"].float().to(dev), g.ndata["p_hits"].float().to(dev)), dim=1
    )

    from gatr.interface import extract_point, extract_scalar

    embedded_outputs, scalar_outputs = model.gatr(
        embedded_inputs, scalars=scalars, attention_mask=mask
    )
    points = extract_point(embedded_outputs[:, 0, :])
    nodewise_outputs = extract_scalar(embedded_outputs)
    x_point = points
    x_scalar = torch.cat(
        (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
    )
    # Per-node clustering-space coordinates and condensation strength beta.
    x_cluster_coord = model.clustering(x_point)
    beta = model.beta(x_scalar)

    g.ndata["final_cluster"] = x_cluster_coord
    g.ndata["beta"] = beta.view(-1)

    # 3. Density-peak clustering
    labels = DPC_custom_CLD(x_cluster_coord, g, dev)
    labels, _ = remove_bad_tracks_from_cluster(g, labels)

    # 4. Build hit→cluster table
    n_hits = g.num_nodes()
    hit_types_raw = g.ndata["hit_type"].cpu().numpy()
    hit_type_names = {1: "track", 2: "ECAL", 3: "HCAL", 4: "muon"}

    # Build pandora cluster ID per node (hits first, then tracks)
    # NOTE(review): assumes graph nodes are ordered calorimeter hits first,
    # then tracks — confirm against create_graph.
    # Use min of array lengths for graceful handling when CSV was modified
    n_calo = len(np.asarray(event.get("X_hit", [])))
    pandora_cluster_ids = np.full(n_hits, -1, dtype=np.int64)
    if len(pfo_calohit) > 0:
        n_assign = min(len(pfo_calohit), n_calo)
        pandora_cluster_ids[:n_assign] = pfo_calohit[:n_assign]
    n_tracks = n_hits - n_calo
    if n_tracks > 0 and len(pfo_track) > 0:
        n_assign = min(len(pfo_track), n_tracks)
        pandora_cluster_ids[n_calo:n_calo + n_assign] = pfo_track[:n_assign]

    hit_cluster_df = pd.DataFrame({
        "hit_index": np.arange(n_hits),
        "cluster_id": labels.cpu().numpy(),
        "pandora_cluster_id": pandora_cluster_ids,
        "hit_type_id": hit_types_raw,
        "hit_type": [hit_type_names.get(int(t), str(int(t))) for t in hit_types_raw],
        "x": g.ndata["pos_hits_xyz"][:, 0].cpu().numpy(),
        "y": g.ndata["pos_hits_xyz"][:, 1].cpu().numpy(),
        "z": g.ndata["pos_hits_xyz"][:, 2].cpu().numpy(),
        "hit_energy": g.ndata["e_hits"].view(-1).cpu().numpy(),
        "cluster_x": x_cluster_coord[:, 0].cpu().numpy(),
        "cluster_y": x_cluster_coord[:, 1].cpu().numpy(),
        "cluster_z": x_cluster_coord[:, 2].cpu().numpy(),
    })

    # 5. Per-cluster summary (basic, before energy correction)
    unique_labels = torch.unique(labels)
    # cluster 0 = noise
    cluster_ids = unique_labels[unique_labels > 0].cpu().numpy()

    from torch_scatter import scatter_add

    # Scatter-sum per cluster label: hit energies, track momenta, hit counts.
    e_per_cluster = scatter_add(
        g.ndata["e_hits"].view(-1).to(dev), labels.to(dev)
    )
    p_per_cluster = scatter_add(
        g.ndata["p_hits"].view(-1).to(dev), labels.to(dev)
    )
    n_hits_per_cluster = scatter_add(
        torch.ones(n_hits, device=dev), labels.to(dev)
    )
    # Check if any cluster has a track (→ charged)
    is_track_per_cluster = scatter_add(
        (g.ndata["hit_type"].to(dev) == 1).float(), labels.to(dev)
    )

    rows = []
    for cid in cluster_ids:
        mask_c = labels == cid
        e_sum = e_per_cluster[cid].item()
        p_sum = p_per_cluster[cid].item()
        n_h = int(n_hits_per_cluster[cid].item())
        has_track = is_track_per_cluster[cid].item() >= 1
        # Mean position
        pos_mean = g.ndata["pos_hits_xyz"][mask_c].mean(dim=0).cpu().numpy()
        rows.append({
            "cluster_id": int(cid),
            "energy_sum_hits": round(e_sum, 4),
            "p_track": round(p_sum, 4) if has_track else 0.0,
            "n_hits": n_h,
            "is_charged": has_track,
            "mean_x": round(float(pos_mean[0]), 2),
            "mean_y": round(float(pos_mean[1]), 2),
            "mean_z": round(float(pos_mean[2]), 2),
        })

    particles_df = pd.DataFrame(rows)

    # 6. If energy correction is available, run it
    if args.correction and hasattr(model, "energy_correction"):
        try:
            particles_df = _run_energy_correction(
                model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev
            )
        except Exception as e:
            # Attach a note but don't crash – the basic table is still useful
            particles_df["note"] = f"Energy correction failed: {e}"

    return particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _extract_mc_particles(event):
|
| 478 |
+
"""Build a DataFrame of MC truth particles from the event's ``X_gen``."""
|
| 479 |
+
x_gen = np.asarray(event.get("X_gen", []))
|
| 480 |
+
if x_gen.ndim != 2 or x_gen.shape[0] == 0 or x_gen.shape[1] < 18:
|
| 481 |
+
return pd.DataFrame()
|
| 482 |
+
|
| 483 |
+
rows = []
|
| 484 |
+
for i in range(x_gen.shape[0]):
|
| 485 |
+
pid_raw = int(x_gen[i, 0])
|
| 486 |
+
rows.append({
|
| 487 |
+
"particle_idx": i,
|
| 488 |
+
"pid": pid_raw,
|
| 489 |
+
"pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
|
| 490 |
+
"gen_status": int(x_gen[i, 1]),
|
| 491 |
+
"energy": round(float(x_gen[i, 8]), 4),
|
| 492 |
+
"momentum": round(float(x_gen[i, 11]), 4),
|
| 493 |
+
"px": round(float(x_gen[i, 12]), 4),
|
| 494 |
+
"py": round(float(x_gen[i, 13]), 4),
|
| 495 |
+
"pz": round(float(x_gen[i, 14]), 4),
|
| 496 |
+
"mass": round(float(x_gen[i, 10]), 4),
|
| 497 |
+
"theta": round(float(x_gen[i, 4]), 4),
|
| 498 |
+
"phi": round(float(x_gen[i, 5]), 4),
|
| 499 |
+
"vx": round(float(x_gen[i, 15]), 4),
|
| 500 |
+
"vy": round(float(x_gen[i, 16]), 4),
|
| 501 |
+
"vz": round(float(x_gen[i, 17]), 4),
|
| 502 |
+
})
|
| 503 |
+
return pd.DataFrame(rows)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def _extract_pandora_particles(event):
|
| 507 |
+
"""Build a DataFrame of Pandora PFO particles from the event's ``X_pandora``.
|
| 508 |
+
|
| 509 |
+
``X_pandora`` columns (per PFO):
|
| 510 |
+
0: pid (PDG ID)
|
| 511 |
+
1–3: px, py, pz (momentum components at reference point)
|
| 512 |
+
4–6: ref_x, ref_y, ref_z (reference point)
|
| 513 |
+
7: energy
|
| 514 |
+
8: momentum magnitude
|
| 515 |
+
|
| 516 |
+
Returns (pandora_particles_df, pfo_hit_links, pfo_track_links) where
|
| 517 |
+
*pfo_hit_links* and *pfo_track_links* are integer arrays mapping each
|
| 518 |
+
hit/track to a PFO index (0-based, -1 = unassigned).
|
| 519 |
+
"""
|
| 520 |
+
x_pandora = np.asarray(event.get("X_pandora", []))
|
| 521 |
+
pfo_calohit = np.asarray(event.get("pfo_calohit", []), dtype=np.int64)
|
| 522 |
+
pfo_track = np.asarray(event.get("pfo_track", []), dtype=np.int64)
|
| 523 |
+
|
| 524 |
+
if x_pandora.ndim != 2 or x_pandora.shape[0] == 0 or x_pandora.shape[1] < 9:
|
| 525 |
+
return pd.DataFrame(), pfo_calohit, pfo_track
|
| 526 |
+
|
| 527 |
+
rows = []
|
| 528 |
+
for i in range(x_pandora.shape[0]):
|
| 529 |
+
pid_raw = int(x_pandora[i, 0])
|
| 530 |
+
rows.append({
|
| 531 |
+
"pfo_idx": i,
|
| 532 |
+
"pid": pid_raw,
|
| 533 |
+
"pdg_name": _ABS_PDG_NAME.get(abs(pid_raw), str(pid_raw)),
|
| 534 |
+
"energy": round(float(x_pandora[i, 7]), 4),
|
| 535 |
+
"momentum": round(float(x_pandora[i, 8]), 4),
|
| 536 |
+
"px": round(float(x_pandora[i, 1]), 4),
|
| 537 |
+
"py": round(float(x_pandora[i, 2]), 4),
|
| 538 |
+
"pz": round(float(x_pandora[i, 3]), 4),
|
| 539 |
+
"ref_x": round(float(x_pandora[i, 4]), 2),
|
| 540 |
+
"ref_y": round(float(x_pandora[i, 5]), 2),
|
| 541 |
+
"ref_z": round(float(x_pandora[i, 6]), 2),
|
| 542 |
+
})
|
| 543 |
+
return pd.DataFrame(rows), pfo_calohit, pfo_track
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def _run_energy_correction(model, g, x_cluster_coord, beta, labels, y_data, particles_df, dev):
|
| 547 |
+
"""Run the energy correction & PID branch and enrich *particles_df*."""
|
| 548 |
+
from src.layers.shower_matching import match_showers, obtain_intersection_matrix, obtain_union_matrix
|
| 549 |
+
from torch_scatter import scatter_add, scatter_mean
|
| 550 |
+
from src.utils.post_clustering_features import (
|
| 551 |
+
get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
|
| 555 |
+
|
| 556 |
+
# Re-create per-cluster sub-graphs expected by the correction pipeline
|
| 557 |
+
particle_ids = torch.unique(g.ndata["particle_number"])
|
| 558 |
+
shower_p_unique = torch.unique(labels)
|
| 559 |
+
model_output_dummy = x # used only for device by match_showers
|
| 560 |
+
|
| 561 |
+
shower_p_unique_m, row_ind, col_ind, i_m_w, _ = match_showers(
|
| 562 |
+
labels, {"graph": g, "part_true": y_data},
|
| 563 |
+
particle_ids, model_output_dummy, 0, 0, None,
|
| 564 |
+
)
|
| 565 |
+
row_ind = torch.Tensor(row_ind).to(dev).long()
|
| 566 |
+
col_ind = torch.Tensor(col_ind).to(dev).long()
|
| 567 |
+
if torch.sum(particle_ids == 0) > 0:
|
| 568 |
+
row_ind_ = row_ind - 1
|
| 569 |
+
else:
|
| 570 |
+
row_ind_ = row_ind
|
| 571 |
+
index_matches = (col_ind + 1).to(dev).long()
|
| 572 |
+
|
| 573 |
+
# Build per-cluster sub-graphs (matched + fakes)
|
| 574 |
+
graphs_matched = []
|
| 575 |
+
true_energies = []
|
| 576 |
+
reco_energies = []
|
| 577 |
+
pids_matched = []
|
| 578 |
+
coords_matched = []
|
| 579 |
+
e_true_daughters = []
|
| 580 |
+
|
| 581 |
+
for j, sh_label in enumerate(index_matches):
|
| 582 |
+
if torch.sum(sh_label == index_matches) == 1:
|
| 583 |
+
mask = labels == sh_label
|
| 584 |
+
sg = dgl.graph(([], []))
|
| 585 |
+
sg.add_nodes(int(mask.sum()))
|
| 586 |
+
sg = sg.to(dev)
|
| 587 |
+
sg.ndata["h"] = g.ndata["h"][mask]
|
| 588 |
+
if "pos_pxpypz" in g.ndata:
|
| 589 |
+
sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
|
| 590 |
+
if "pos_pxpypz_at_vertex" in g.ndata:
|
| 591 |
+
sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
|
| 592 |
+
sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
|
| 593 |
+
energy_t = y_data.E.to(dev)
|
| 594 |
+
true_e = energy_t[row_ind_[j]]
|
| 595 |
+
pids_matched.append(y_data.pid[row_ind_[j]].item())
|
| 596 |
+
coords_matched.append(y_data.coord[row_ind_[j]].detach().cpu().numpy())
|
| 597 |
+
e_true_daughters.append(y_data.m[row_ind_[j]].to(dev))
|
| 598 |
+
reco_e = torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask])
|
| 599 |
+
graphs_matched.append(sg)
|
| 600 |
+
true_energies.append(true_e.view(-1))
|
| 601 |
+
reco_energies.append(reco_e.view(-1))
|
| 602 |
+
|
| 603 |
+
# Add fakes
|
| 604 |
+
pred_showers = shower_p_unique_m.clone()
|
| 605 |
+
pred_showers[index_matches] = -1
|
| 606 |
+
pred_showers[0] = -1
|
| 607 |
+
fakes_mask = pred_showers != -1
|
| 608 |
+
fakes_idx = torch.where(fakes_mask)[0]
|
| 609 |
+
|
| 610 |
+
graphs_fakes = []
|
| 611 |
+
reco_fakes = []
|
| 612 |
+
for fi in fakes_idx:
|
| 613 |
+
mask = labels == fi
|
| 614 |
+
sg = dgl.graph(([], []))
|
| 615 |
+
sg.add_nodes(int(mask.sum()))
|
| 616 |
+
sg = sg.to(dev)
|
| 617 |
+
sg.ndata["h"] = g.ndata["h"][mask]
|
| 618 |
+
if "pos_pxpypz" in g.ndata:
|
| 619 |
+
sg.ndata["pos_pxpypz"] = g.ndata["pos_pxpypz"][mask]
|
| 620 |
+
if "pos_pxpypz_at_vertex" in g.ndata:
|
| 621 |
+
sg.ndata["pos_pxpypz_at_vertex"] = g.ndata["pos_pxpypz_at_vertex"][mask]
|
| 622 |
+
sg.ndata["chi_squared_tracks"] = g.ndata["chi_squared_tracks"][mask]
|
| 623 |
+
graphs_fakes.append(sg)
|
| 624 |
+
reco_fakes.append(torch.sum(g.ndata["e_hits"].view(-1).to(dev)[mask]).view(-1))
|
| 625 |
+
|
| 626 |
+
if not graphs_matched and not graphs_fakes:
|
| 627 |
+
return particles_df
|
| 628 |
+
|
| 629 |
+
all_graphs = dgl.batch(graphs_matched + graphs_fakes)
|
| 630 |
+
sum_e = torch.cat(reco_energies + reco_fakes, dim=0)
|
| 631 |
+
|
| 632 |
+
# Compute high-level features
|
| 633 |
+
batch_num_nodes = all_graphs.batch_num_nodes()
|
| 634 |
+
batch_idx = []
|
| 635 |
+
for i, n in enumerate(batch_num_nodes):
|
| 636 |
+
batch_idx.extend([i] * n)
|
| 637 |
+
batch_idx = torch.tensor(batch_idx).to(dev)
|
| 638 |
+
|
| 639 |
+
all_graphs.ndata["h"][:, 0:3] = all_graphs.ndata["h"][:, 0:3] / 3300
|
| 640 |
+
graphs_sum_features = scatter_add(all_graphs.ndata["h"], batch_idx, dim=0)
|
| 641 |
+
graphs_sum_features = graphs_sum_features[batch_idx]
|
| 642 |
+
betas = torch.sigmoid(all_graphs.ndata["h"][:, -1])
|
| 643 |
+
all_graphs.ndata["h"] = torch.cat(
|
| 644 |
+
(all_graphs.ndata["h"], graphs_sum_features), dim=1
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
high_level = get_post_clustering_features(all_graphs, sum_e)
|
| 648 |
+
extra_features = get_extra_features(all_graphs, betas)
|
| 649 |
+
|
| 650 |
+
n_clusters = high_level.shape[0]
|
| 651 |
+
pred_energy = torch.ones(n_clusters, device=dev)
|
| 652 |
+
pred_pos = torch.ones(n_clusters, 3, device=dev)
|
| 653 |
+
pred_pid = torch.ones(n_clusters, device=dev).long()
|
| 654 |
+
|
| 655 |
+
node_features_avg = scatter_mean(all_graphs.ndata["h"], batch_idx, dim=0)[:, 0:3]
|
| 656 |
+
eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
|
| 657 |
+
phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
|
| 658 |
+
high_level = torch.cat(
|
| 659 |
+
(high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
num_tracks = high_level[:, 7]
|
| 663 |
+
charged_idx = torch.where(num_tracks >= 1)[0]
|
| 664 |
+
neutral_idx = torch.where(num_tracks < 1)[0]
|
| 665 |
+
|
| 666 |
+
def zero_nans(t):
|
| 667 |
+
out = t.clone()
|
| 668 |
+
out[out != out] = 0
|
| 669 |
+
return out
|
| 670 |
+
|
| 671 |
+
feats_charged = zero_nans(high_level[charged_idx])
|
| 672 |
+
feats_neutral = zero_nans(high_level[neutral_idx])
|
| 673 |
+
|
| 674 |
+
# Run charged prediction
|
| 675 |
+
charged_energies = model.energy_correction.model_charged.charged_prediction(
|
| 676 |
+
all_graphs, charged_idx, feats_charged,
|
| 677 |
+
)
|
| 678 |
+
# Run neutral prediction
|
| 679 |
+
neutral_energies, neutral_pxyz_avg = model.energy_correction.model_neutral.neutral_prediction(
|
| 680 |
+
all_graphs, neutral_idx, feats_neutral,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
pids_charged = model.energy_correction.pids_charged
|
| 684 |
+
pids_neutral = model.energy_correction.pids_neutral
|
| 685 |
+
|
| 686 |
+
if len(pids_charged):
|
| 687 |
+
ch_e, ch_pos, ch_pid_logits, ch_ref = charged_energies
|
| 688 |
+
else:
|
| 689 |
+
ch_e, ch_pos, _ = charged_energies
|
| 690 |
+
ch_pid_logits = None
|
| 691 |
+
|
| 692 |
+
if len(pids_neutral):
|
| 693 |
+
ne_e, ne_pos, ne_pid_logits, ne_ref = neutral_energies
|
| 694 |
+
else:
|
| 695 |
+
ne_e, ne_pos, _ = neutral_energies
|
| 696 |
+
ne_pid_logits = None
|
| 697 |
+
|
| 698 |
+
pred_energy[charged_idx.flatten()] = ch_e if len(charged_idx) else pred_energy[charged_idx.flatten()]
|
| 699 |
+
pred_energy[neutral_idx.flatten()] = ne_e if len(neutral_idx) else pred_energy[neutral_idx.flatten()]
|
| 700 |
+
|
| 701 |
+
if ch_pid_logits is not None and len(charged_idx):
|
| 702 |
+
ch_labels = np.array(pids_charged)[np.argmax(ch_pid_logits.cpu().detach().numpy(), axis=1)]
|
| 703 |
+
pred_pid[charged_idx.flatten()] = torch.tensor(ch_labels).long().to(dev)
|
| 704 |
+
if ne_pid_logits is not None and len(neutral_idx):
|
| 705 |
+
ne_labels = np.array(pids_neutral)[np.argmax(ne_pid_logits.cpu().detach().numpy(), axis=1)]
|
| 706 |
+
pred_pid[neutral_idx.flatten()] = torch.tensor(ne_labels).long().to(dev)
|
| 707 |
+
|
| 708 |
+
pred_energy[pred_energy < 0] = 0.0
|
| 709 |
+
|
| 710 |
+
# Direction
|
| 711 |
+
if len(charged_idx):
|
| 712 |
+
pred_pos[charged_idx.flatten()] = ch_pos.float().to(dev)
|
| 713 |
+
if len(neutral_idx):
|
| 714 |
+
pred_pos[neutral_idx.flatten()] = ne_pos.float().to(dev)
|
| 715 |
+
|
| 716 |
+
# Build enriched output DataFrame
|
| 717 |
+
n_matched = len(graphs_matched)
|
| 718 |
+
rows = []
|
| 719 |
+
for k in range(n_clusters):
|
| 720 |
+
is_fake = k >= n_matched
|
| 721 |
+
pid_cls = int(pred_pid[k].item())
|
| 722 |
+
rows.append({
|
| 723 |
+
"cluster_id": k + 1,
|
| 724 |
+
"corrected_energy": round(pred_energy[k].item(), 4),
|
| 725 |
+
"raw_energy": round(sum_e[k].item(), 4),
|
| 726 |
+
"pid_class": pid_cls,
|
| 727 |
+
"pid_label": _PID_LABELS.get(pid_cls, str(pid_cls)),
|
| 728 |
+
"px": round(pred_pos[k, 0].item(), 4),
|
| 729 |
+
"py": round(pred_pos[k, 1].item(), 4),
|
| 730 |
+
"pz": round(pred_pos[k, 2].item(), 4),
|
| 731 |
+
"is_charged": bool(k in charged_idx),
|
| 732 |
+
"is_fake": is_fake,
|
| 733 |
+
})
|
| 734 |
+
|
| 735 |
+
return pd.DataFrame(rows)
|
src/layers/clustering.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Clustering algorithms for particle-flow reconstruction.
|
| 2 |
+
|
| 3 |
+
Adapted from densitypeakclustering (https://github.com/lanbing510/DensityPeakCluster).
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from torch_scatter import scatter_add
|
| 8 |
+
import densitypeakclustering as dc
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def local_density_energy(D, d_c, energies, normalize=False):
    """Energy-weighted local density for density-peak clustering.

    For each point ``s`` the density is the sum, over all points ``j`` closer
    than the cutoff ``d_c``, of ``energies[j] * exp(-(D[s, j] / d_c) ** 2)``
    (a Gaussian kernel truncated at ``d_c``).

    Args:
        D: (n, n) pairwise distance matrix. NaN entries are treated as
           "not within the cutoff" and contribute nothing.
        d_c: cutoff distance (also the Gaussian kernel width).
        energies: (n,) per-point energies used as density weights.
        normalize: if True, scale the densities so the maximum is 1.

    Returns:
        (n,) array of local densities.
    """
    # Vectorized form of the per-row loop: zero out contributions outside the
    # cutoff, then one matrix-vector product gives every row sum at once.
    weights = np.exp(-((D / d_c) ** 2))
    # ~(D < d_c) is also True for NaN distances, matching the original
    # behavior of excluding NaN entries from the density sum.
    weights[~(D < d_c)] = 0.0
    rho = weights @ energies
    if normalize:
        rho = rho / np.max(rho)
    return rho
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def DPC_custom_CLD(X, g, device):
    """Density-peak clustering of hits in the learned coordinate space.

    Runs the DensityPeakCluster pipeline with an energy-weighted local
    density, then keeps only "core" hits that lie within 0.5 of their
    cluster center; everything else is labelled as noise.

    Args:
        X: (n, d) per-hit clustering coordinates (torch tensor).
        g: graph whose ``ndata["e_hits"]`` holds per-hit energies.
        device: device for the returned label tensor.

    Returns:
        (n,) long tensor of cluster labels; 0 means noise/unassigned,
        clusters are numbered from 1.
    """
    # Clustering hyper-parameters.
    cutoff_distance = 0.1
    min_density = 0.05
    min_separation = 0.4

    dist = dc.distance_matrix(X.detach().cpu())
    hit_energies = g.ndata["e_hits"].view(-1).cpu().numpy()
    density = local_density_energy(dist, cutoff_distance, hit_energies)
    separation, nearest_denser = dc.distance_to_larger_density(dist, density)
    centers = dc.cluster_centers(
        density, separation, rho_min=min_density, delta_min=min_separation
    )
    assignment = dc.assign_cluster_id(density, nearest_denser, centers)

    # Restrict each cluster to its core: hits within 0.5 of the center.
    core_assignment = np.full(len(X), -1)
    dist[np.isnan(dist)] = 0
    for cluster_idx, center in enumerate(centers):
        close_members = np.where(
            (assignment == cluster_idx) & (dist[:, center] < 0.5)
        )[0]
        core_assignment[close_members] = cluster_idx
    # Shift by +1 so label 0 is reserved for noise.
    return (torch.Tensor(core_assignment) + 1).long().to(device)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def remove_bad_tracks_from_cluster(g, labels_hdb):
    """Reassign incompatible tracks inside clusters to the noise label.

    A track (``hit_type == 1``) in a cluster is "bad" when the total
    calorimeter energy of the cluster deviates from the track momentum by
    more than ``4 * 0.5 / sqrt(p)`` — a four-sigma window with a
    0.5/sqrt(p)-style relative resolution — unless the cluster also contains
    muon hits (``hit_type == 4``), where the energy comparison is not
    meaningful.

    Args:
        g: graph whose ``ndata`` holds per-node ``hit_type``, ``e_hits``
           (calorimeter energy) and ``p_hits`` (track momentum).
        labels_hdb: per-node integer cluster labels (0 = noise).

    Returns:
        Tuple of:
        - corrected labels, with bad tracks moved to label 0;
        - float tensor flagging (with 1) every node of clusters whose track
          content was modified.
    """
    is_track = g.ndata["hit_type"] == 1
    is_muon_hit = g.ndata["hit_type"] == 4
    corrected_labels = labels_hdb.clone()
    changed_clusters = 0.0 * (labels_hdb.clone())
    for i in range(0, torch.max(labels_hdb) + 1):
        in_cluster = labels_hdb == i
        # Only inspect real clusters (i > 0) that contain at least one track.
        if torch.sum(is_track[in_cluster]) > 0 and i > 0:
            e_cluster = torch.sum(g.ndata["e_hits"][in_cluster])
            p_track = g.ndata["p_hits"][in_cluster * is_track]
            n_muon_hits = torch.sum(in_cluster * is_muon_hit)
            rel_diff = (torch.abs(e_cluster - p_track) / p_track).view(-1)
            four_sigma = 4 * 0.5 / torch.sqrt(p_track).view(-1)
            # Tracks failing the 4-sigma E/p test, vetoed if muon hits exist.
            bad = (rel_diff > four_sigma) * (n_muon_hits < 1)
            track_nodes = torch.nonzero(in_cluster & is_track).view(-1)
            bad_track_nodes = track_nodes[bad]
            corrected_labels[bad_track_nodes] = 0
            # Bug fix: the previous check summed node *indices*
            # (torch.sum(bad_track_nodes) > 0), which missed a bad track
            # sitting at node index 0; test for emptiness instead.
            if bad_track_nodes.numel() > 0:
                changed_clusters[in_cluster] = 1
    return corrected_labels, changed_clusters
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def remove_labels_of_double_showers(labels, g):
    """Detach one track from clusters that absorbed two tracks.

    For every cluster (label > 0) containing exactly two tracks
    (``hit_type == 1``), one track is deemed wrong and its node is moved to
    the noise label 0. The candidate is chosen by, in order of priority:
    distance to the cluster's innermost calorimeter hits, energy
    compatibility with the summed cluster energy, and track chi-squared.

    NOTE(review): ``labels`` is modified in place and also returned.
    """
    # Per-label track multiplicity and per-label summed hit energy.
    is_track_per_shower = scatter_add(1 * (g.ndata["hit_type"] == 1), labels).int()
    # NOTE(review): .int() truncates the summed energies — presumably
    # intentional (only coarse comparison needed); confirm.
    e_hits_sum = scatter_add(g.ndata["e_hits"].view(-1), labels.view(-1).long()).int()
    mask_tracks = g.ndata["hit_type"] == 1
    for i, label_i in enumerate(torch.unique(labels)):
        if is_track_per_shower[label_i] == 2:
            if label_i > 0:
                sum_pred_2 = e_hits_sum[label_i]
                mask_labels_i = labels == label_i
                mask_label_i_and_is_track = mask_labels_i * mask_tracks
                # Last feature column of 'h' holds the track energy
                # (assumption based on usage here — TODO confirm).
                tracks_E = g.ndata['h'][:, -1][mask_label_i_and_is_track]
                chi_tracks = g.ndata['chi_squared_tracks'][mask_label_i_and_is_track]
                # Candidate by energy: track furthest from the cluster energy.
                ind_min_E = torch.argmax(torch.abs(tracks_E - sum_pred_2))
                # Candidate by fit quality: track with worst chi-squared.
                ind_min_chi = torch.argmax(chi_tracks)
                # Within the cluster: hit_type == 2 are calorimeter-side hits,
                # hit_type == 1 are the tracks (naming kept from original).
                mask_hit_type_t1 = g.ndata["hit_type"][mask_labels_i] == 2
                mask_hit_type_t2 = g.ndata["hit_type"][mask_labels_i] == 1
                mask_all = mask_hit_type_t1
                # Pick the ~10 innermost calorimeter hits by radial distance
                # to estimate the shower entry position.
                index_sorted = torch.argsort(g.ndata["radial_distance"][mask_labels_i][mask_hit_type_t1])
                mask_sorted_ind = index_sorted < 10
                mean_pos_cluster = torch.mean(
                    g.ndata["pos_hits_xyz"][mask_labels_i][mask_all][mask_sorted_ind], dim=0
                )
                pos_track = g.ndata["pos_hits_xyz"][mask_labels_i][mask_hit_type_t2]
                # Track-to-cluster distance in meters (positions are in mm —
                # assumption from the /1000 scaling; confirm).
                distance_track_cluster = torch.norm(pos_track - mean_pos_cluster, dim=1) / 1000
                ind_max_dtc = torch.argmax(distance_track_cluster)
                if torch.min(distance_track_cluster) < 0.4:
                    # One track is clearly close: drop the farther one.
                    ind_min = ind_max_dtc
                elif ind_min_E == ind_min_chi:
                    # Energy and chi-squared agree on the culprit.
                    ind_min = ind_min_E
                elif torch.max(chi_tracks - torch.min(chi_tracks)) < 2:
                    # Chi-squared values too similar to discriminate; trust energy.
                    ind_min = ind_min_E
                else:
                    ind_min = ind_min_chi
                # Map the within-cluster track index back to a global node index.
                ind_change = torch.argwhere(mask_label_i_and_is_track)[ind_min]
                labels[ind_change] = 0
    return labels
|
src/layers/inference_oc.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file includes code adapted from:
|
| 3 |
+
|
| 4 |
+
densitypeakclustering
|
| 5 |
+
https://github.com/lanbing510/DensityPeakCluster
|
| 6 |
+
|
| 7 |
+
The original implementation has been modified and integrated into this project.
|
| 8 |
+
Please refer to the original repository for authorship, documentation,
|
| 9 |
+
and license information.
|
| 10 |
+
"""
|
| 11 |
+
import dgl
|
| 12 |
+
import torch
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
import wandb
|
| 16 |
+
|
| 17 |
+
from src.layers.clustering import (
|
| 18 |
+
local_density_energy,
|
| 19 |
+
DPC_custom_CLD,
|
| 20 |
+
remove_bad_tracks_from_cluster,
|
| 21 |
+
remove_labels_of_double_showers,
|
| 22 |
+
)
|
| 23 |
+
from src.layers.shower_matching import (
|
| 24 |
+
CachedIndexList,
|
| 25 |
+
get_labels_pandora,
|
| 26 |
+
obtain_intersection_matrix,
|
| 27 |
+
obtain_union_matrix,
|
| 28 |
+
obtain_intersection_values,
|
| 29 |
+
match_showers,
|
| 30 |
+
)
|
| 31 |
+
from src.layers.shower_dataframe import (
|
| 32 |
+
get_correction_per_shower,
|
| 33 |
+
distance_to_true_cluster_of_track,
|
| 34 |
+
distance_to_cluster_track,
|
| 35 |
+
generate_showers_data_frame,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Re-export everything so existing callers (utils_training, Gatr_pf_e_noise, …)
|
| 39 |
+
# that do `from src.layers.inference_oc import X` continue to work unchanged.
|
| 40 |
+
__all__ = [
|
| 41 |
+
"local_density_energy",
|
| 42 |
+
"DPC_custom_CLD",
|
| 43 |
+
"remove_bad_tracks_from_cluster",
|
| 44 |
+
"remove_labels_of_double_showers",
|
| 45 |
+
"CachedIndexList",
|
| 46 |
+
"get_labels_pandora",
|
| 47 |
+
"obtain_intersection_matrix",
|
| 48 |
+
"obtain_union_matrix",
|
| 49 |
+
"obtain_intersection_values",
|
| 50 |
+
"match_showers",
|
| 51 |
+
"get_correction_per_shower",
|
| 52 |
+
"distance_to_true_cluster_of_track",
|
| 53 |
+
"distance_to_cluster_track",
|
| 54 |
+
"generate_showers_data_frame",
|
| 55 |
+
"log_efficiency",
|
| 56 |
+
"store_at_batch_end",
|
| 57 |
+
"create_and_store_graph_output",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def log_efficiency(df, pandora=False, clustering=False):
    """Compute shower-matching efficiency and log it to wandb.

    Efficiency is the fraction of reconstructable truth showers (rows with a
    finite ``reco_showers_E``) that received a matched predicted shower
    (finite ``pred_showers_E``).

    Args:
        df: showers dataframe with ``reco_showers_E`` / ``pred_showers_E``.
        pandora: log under the Pandora-specific metric name.
        clustering: log under the clustering-specific metric name.
    """
    has_reco = ~np.isnan(df["reco_showers_E"])
    matched_pred = df["pred_showers_E"][has_reco].values
    efficiency = np.sum(~np.isnan(matched_pred)) / len(matched_pred)
    if pandora:
        metric_name = "efficiency validation pandora"
    elif clustering:
        metric_name = "efficiency validation clustering"
    else:
        metric_name = "efficiency validation"
    wandb.log({metric_name: efficiency})
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _make_save_path(path_save, local_rank, step, epoch, suffix=""):
|
| 75 |
+
return path_save + str(local_rank) + "_" + str(step) + "_" + str(epoch) + suffix + ".pt"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
    pandora_available=False,
):
    """Persist per-batch shower dataframes and log their efficiencies.

    The model dataframe is pickled when both ``store`` and ``predict`` are
    set; the Pandora dataframe additionally requires ``pandora_available``.
    Efficiency is always logged for the model dataframe, and for Pandora
    only in prediction mode.
    """
    should_pickle = store and predict
    model_path = _make_save_path(path_save, local_rank, step, epoch)
    if should_pickle:
        df_batch1.to_pickle(model_path)
    if predict and pandora_available:
        pandora_path = _make_save_path(path_save, local_rank, step, epoch, "_pandora")
        if should_pickle:
            df_batch_pandora.to_pickle(pandora_path)
    log_efficiency(df_batch1)
    if predict and pandora_available:
        log_efficiency(df_batch_pandora, pandora=True)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    e_corr=None,
    ec_x=None,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pred_pid=None,
    number_of_fakes=None,
    extra_features=None,
    fakes_labels=None,
    pandora_available=False,
    truth_tracks=False,
):
    """Cluster model output per event, match showers to truth and build dataframes.

    For every graph in the batch: cluster the hits (ground-truth labels,
    or density-peak clustering followed by bad-track removal), match the
    resulting showers to truth particles, and accumulate one dataframe row
    set per event. Optionally does the same for Pandora labels and stores
    the concatenated dataframes at batch end.

    Returns:
        ``(df_batch_pandora, df_batch1, total_number_events)`` when
        ``predict`` is True, otherwise just ``df_batch1``.

    NOTE(review): ``fakes_labels`` is accepted but never used in this body.
    """
    number_of_showers_total = 0
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    # Model output layout: columns 0-2 clustering coords, 3 beta,
    # 4 per-hit energy correction (only used when no global e_corr given).
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if e_corr is None:
        batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)
    df_list1 = []
    df_list_pandora = []
    for i in range(0, len(graphs)):
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        # Restrict the truth container to this event's particles.
        y1 = y.copy()
        y1.mask(mask)
        dic["part_true"] = y1
        X = dic["graph"].ndata["coords"]
        labels_clusters_removed_tracks = torch.zeros(
            dic["graph"].num_nodes(), device=model_output.device
        )
        if use_gt_clusters:
            labels_hdb = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = DPC_custom_CLD(X, dic["graph"], model_output.device)
            if not truth_tracks:
                # Drop tracks whose momentum is incompatible with the
                # cluster energy (only for reconstructed tracks).
                labels_hdb, labels_clusters_removed_tracks = remove_bad_tracks_from_cluster(
                    dic["graph"], labels_hdb
                )
        if predict and pandora_available:
            labels_pandora = get_labels_pandora(dic, model_output.device)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])

        # Hungarian-style matching of predicted showers to truth particles.
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, iou_m = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            hdbscan=True,
        )
        if predict and pandora_available:
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                iou_m_pandora,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
            )

        # More than one unique label means at least one non-noise shower.
        if len(shower_p_unique_hdb) > 1:
            df_event1, number_of_showers_total1, number_of_fake_showers_total1 = generate_showers_data_frame(
                labels_hdb,
                dic,
                shower_p_unique_hdb,
                particle_ids,
                row_ind_hdb,
                col_ind_hdb,
                i_m_w_hdb,
                e_corr=e_corr,
                number_of_showers_total=number_of_showers_total1,
                step=step,
                number_in_batch=total_number_events,
                ec_x=ec_x,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                number_of_fakes=number_of_fakes,
                number_of_fake_showers_total=number_of_fake_showers_total1,
                extra_features=extra_features,
                labels_clusters_removed_tracks=labels_clusters_removed_tracks,
            )
            if len(df_event1) > 1:
                df_list1.append(df_event1)
            if predict and pandora_available:
                df_event_pandora = generate_showers_data_frame(
                    labels_pandora,
                    dic,
                    shower_p_unique_pandora,
                    particle_ids,
                    row_ind_pandora,
                    col_ind_pandora,
                    i_m_w_pandora,
                    pandora=True,
                    step=step,
                    number_in_batch=total_number_events,
                )
                # In the Pandora path the helper may return a tuple of
                # counters instead of a dataframe; only keep dataframes.
                if df_event_pandora is not None and type(df_event_pandora) is not tuple:
                    df_list_pandora.append(df_event_pandora)
                else:
                    print("Not appending to df_list_pandora")
        total_number_events = total_number_events + 1

    df_batch1 = pd.concat(df_list1)
    if predict and pandora_available:
        df_batch_pandora = pd.concat(df_list_pandora)
    else:
        # NOTE(review): df_batch is assigned but never used.
        df_batch = []
        df_batch_pandora = []
    if store:
        store_at_batch_end(
            path_save,
            df_batch1,
            df_batch_pandora,
            local_rank,
            step,
            epoch,
            predict=predict,
            store=store_epoch,
            pandora_available=pandora_available,
        )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    else:
        return df_batch1
|
src/layers/object_cond.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The loss implementation in this file is adapted from the HGCalML repository:
|
| 3 |
+
|
| 4 |
+
Repository: https://github.com/jkiesele/HGCalML
|
| 5 |
+
File: modules/lossLayers.py
|
| 6 |
+
|
| 7 |
+
Original author: Jan Kieseler
|
| 8 |
+
License: See the original repository for license details.
|
| 9 |
+
|
| 10 |
+
The implementation has been modified and integrated into this project.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Tuple, Union
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from torch_scatter import scatter_max, scatter_add, scatter_mean
|
| 17 |
+
import dgl
|
| 18 |
+
|
| 19 |
+
def safe_index(arr, index):
    """Return the 1-based position of *index* in *arr*, or 0 when absent.

    The +1 shift reserves 0 as the "not found" / padding value for
    one-hot style encodings.
    """
    try:
        return arr.index(index) + 1
    except ValueError:
        return 0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def assert_no_nans(x):
    """Raise AssertionError if the tensor *x* contains any NaN entries.

    The offending tensor is printed before the assertion fires, to aid
    debugging.
    """
    contains_nan = bool(torch.isnan(x).any())
    if contains_nan:
        print(x)
    assert not contains_nan
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def calc_LV_Lbeta(
|
| 37 |
+
original_coords,
|
| 38 |
+
g,
|
| 39 |
+
y,
|
| 40 |
+
distance_threshold,
|
| 41 |
+
energy_correction,
|
| 42 |
+
beta: torch.Tensor,
|
| 43 |
+
cluster_space_coords: torch.Tensor, # Predicted by model
|
| 44 |
+
cluster_index_per_event: torch.Tensor, # Truth hit->cluster index
|
| 45 |
+
batch: torch.Tensor,
|
| 46 |
+
predicted_pid=None, # predicted PID embeddings - will be aggregated by summing up the clusters and applying the post_pid_pool_module MLP afterwards
|
| 47 |
+
# From here on just parameters
|
| 48 |
+
qmin: float = 0.1,
|
| 49 |
+
s_B: float = 1.0,
|
| 50 |
+
noise_cluster_index: int = 0, # cluster_index entries with this value are noise/noise
|
| 51 |
+
frac_combinations=0, # fraction of the all possible pairs to be used for the clustering loss
|
| 52 |
+
use_average_cc_pos=0.0,
|
| 53 |
+
loss_type="hgcalimplementation",
|
| 54 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]:
|
| 55 |
+
"""
|
| 56 |
+
Calculates the L_V and L_beta object condensation losses.
|
| 57 |
+
Concepts:
|
| 58 |
+
- A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)),
|
| 59 |
+
and to exactly one event (batch is (n_hits,))
|
| 60 |
+
- A cluster index of `noise_cluster_index` means the cluster is a noise cluster.
|
| 61 |
+
There is typically one noise cluster per event. Any hit in a noise cluster
|
| 62 |
+
is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a
|
| 63 |
+
better term.
|
| 64 |
+
- An 'object' is a cluster that is *not* a noise cluster.
|
| 65 |
+
beta_stabilizing: Choices are ['paper', 'clip', 'soft_q_scaling']:
|
| 66 |
+
paper: beta is sigmoid(model_output), q = beta.arctanh()**2 + qmin
|
| 67 |
+
clip: beta is clipped to 1-1e-4, q = beta.arctanh()**2 + qmin
|
| 68 |
+
soft_q_scaling: beta is sigmoid(model_output), q = (clip(beta)/1.002).arctanh()**2 + qmin
|
| 69 |
+
huberize_norm_for_V_attractive: Huberizes the norms when used in the attractive potential
|
| 70 |
+
beta_term_option: Choices are ['paper', 'short-range-potential']:
|
| 71 |
+
Choosing 'short-range-potential' introduces a short range potential around high
|
| 72 |
+
beta points, acting like V_attractive.
|
| 73 |
+
Note this function has modifications w.r.t. the implementation in 2002.03605:
|
| 74 |
+
- The norms for V_repulsive are now Gaussian (instead of linear hinge)
|
| 75 |
+
"""
|
| 76 |
+
# remove dummy rows added for dataloader #TODO think of better way to do this
|
| 77 |
+
device = beta.device
|
| 78 |
+
if torch.isnan(beta).any():
|
| 79 |
+
print("There are nans in beta! L198", len(beta[torch.isnan(beta)]))
|
| 80 |
+
|
| 81 |
+
beta = torch.nan_to_num(beta, nan=0.0)
|
| 82 |
+
assert_no_nans(beta)
|
| 83 |
+
# ________________________________
|
| 84 |
+
|
| 85 |
+
# Calculate a bunch of needed counts and indices locally
|
| 86 |
+
|
| 87 |
+
# cluster_index: unique index over events
|
| 88 |
+
# E.g. cluster_index_per_event=[ 0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
|
| 89 |
+
# -> cluster_index=[ 0, 0, 1, 2, 3, 3, 4 ]
|
| 90 |
+
cluster_index, n_clusters_per_event = batch_cluster_indices(
|
| 91 |
+
cluster_index_per_event, batch
|
| 92 |
+
)
|
| 93 |
+
n_clusters = n_clusters_per_event.sum()
|
| 94 |
+
n_hits, cluster_space_dim = cluster_space_coords.size()
|
| 95 |
+
batch_size = batch.max() + 1
|
| 96 |
+
n_hits_per_event = scatter_count(batch)
|
| 97 |
+
|
| 98 |
+
# Index of cluster -> event (n_clusters,)
|
| 99 |
+
batch_cluster = scatter_counts_to_indices(n_clusters_per_event)
|
| 100 |
+
|
| 101 |
+
# Per-hit boolean, indicating whether hit is sig or noise
|
| 102 |
+
is_noise = cluster_index_per_event == noise_cluster_index
|
| 103 |
+
is_sig = ~is_noise
|
| 104 |
+
n_hits_sig = is_sig.sum()
|
| 105 |
+
n_sig_hits_per_event = scatter_count(batch[is_sig])
|
| 106 |
+
|
| 107 |
+
# Per-cluster boolean, indicating whether cluster is an object or noise
|
| 108 |
+
is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
|
| 109 |
+
is_noise_cluster = ~is_object
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if noise_cluster_index != 0:
|
| 113 |
+
raise NotImplementedError
|
| 114 |
+
object_index_per_event = cluster_index_per_event[is_sig] - 1
|
| 115 |
+
object_index, n_objects_per_event = batch_cluster_indices(
|
| 116 |
+
object_index_per_event, batch[is_sig]
|
| 117 |
+
)
|
| 118 |
+
n_hits_per_object = scatter_count(object_index)
|
| 119 |
+
# print("n_hits_per_object", n_hits_per_object)
|
| 120 |
+
batch_object = batch_cluster[is_object]
|
| 121 |
+
n_objects = is_object.sum()
|
| 122 |
+
|
| 123 |
+
assert object_index.size() == (n_hits_sig,)
|
| 124 |
+
assert is_object.size() == (n_clusters,)
|
| 125 |
+
assert torch.all(n_hits_per_object > 0)
|
| 126 |
+
assert object_index.max() + 1 == n_objects
|
| 127 |
+
|
| 128 |
+
# ________________________________
|
| 129 |
+
# L_V term
|
| 130 |
+
|
| 131 |
+
# Calculate q
|
| 132 |
+
q = (beta.clip(0.0, 1 - 1e-4).arctanh() / 1.01) ** 2 + qmin
|
| 133 |
+
assert_no_nans(q)
|
| 134 |
+
assert q.device == device
|
| 135 |
+
assert q.size() == (n_hits,)
|
| 136 |
+
|
| 137 |
+
# Calculate q_alpha, the max q per object, and the indices of said maxima
|
| 138 |
+
# assert hit_energies.shape == q.shape
|
| 139 |
+
# q_alpha, index_alpha = scatter_max(hit_energies[is_sig], object_index)
|
| 140 |
+
q_alpha, index_alpha = scatter_max(q[is_sig], object_index)
|
| 141 |
+
assert q_alpha.size() == (n_objects,)
|
| 142 |
+
|
| 143 |
+
# Get the cluster space coordinates and betas for these maxima hits too
|
| 144 |
+
x_alpha = cluster_space_coords[is_sig][index_alpha]
|
| 145 |
+
x_alpha_original = original_coords[is_sig][index_alpha]
|
| 146 |
+
if use_average_cc_pos > 0:
|
| 147 |
+
x_alpha_sum = scatter_add(
|
| 148 |
+
q[is_sig].view(-1, 1).repeat(1, 3) * cluster_space_coords[is_sig],
|
| 149 |
+
object_index,
|
| 150 |
+
dim=0,
|
| 151 |
+
) # * beta[is_sig].view(-1, 1).repeat(1, 3)
|
| 152 |
+
qbeta_alpha_sum = scatter_add(q[is_sig], object_index) + 1e-9 # * beta[is_sig]
|
| 153 |
+
div_fac = 1 / qbeta_alpha_sum
|
| 154 |
+
div_fac = torch.nan_to_num(div_fac, nan=0)
|
| 155 |
+
x_alpha_mean = torch.mul(x_alpha_sum, div_fac.view(-1, 1).repeat(1, 3))
|
| 156 |
+
x_alpha = use_average_cc_pos * x_alpha_mean + (1 - use_average_cc_pos) * x_alpha
|
| 157 |
+
|
| 158 |
+
beta_alpha = beta[is_sig][index_alpha]
|
| 159 |
+
assert x_alpha.size() == (n_objects, cluster_space_dim)
|
| 160 |
+
assert beta_alpha.size() == (n_objects,)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Connectivity matrix from hit (row) -> cluster (column)
|
| 164 |
+
# Index to matrix, e.g.:
|
| 165 |
+
# [1, 3, 1, 0] --> [
|
| 166 |
+
# [0, 1, 0, 0],
|
| 167 |
+
# [0, 0, 0, 1],
|
| 168 |
+
# [0, 1, 0, 0],
|
| 169 |
+
# [1, 0, 0, 0]
|
| 170 |
+
# ]
|
| 171 |
+
M = torch.nn.functional.one_hot(cluster_index).long()
|
| 172 |
+
|
| 173 |
+
# Anti-connectivity matrix; be sure not to connect hits to clusters in different events!
|
| 174 |
+
M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M
|
| 175 |
+
|
| 176 |
+
# Throw away noise cluster columns; we never need them
|
| 177 |
+
M = M[:, is_object]
|
| 178 |
+
M_inv = M_inv[:, is_object]
|
| 179 |
+
assert M.size() == (n_hits, n_objects)
|
| 180 |
+
assert M_inv.size() == (n_hits, n_objects)
|
| 181 |
+
|
| 182 |
+
# Calculate all norms
|
| 183 |
+
# Warning: Should not be used without a mask!
|
| 184 |
+
# Contains norms between hits and objects from different events
|
| 185 |
+
# (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim)
|
| 186 |
+
# gives (n_hits, n_objects, cluster_space_dim)
|
| 187 |
+
norms = (cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)).norm(dim=-1)
|
| 188 |
+
assert norms.size() == (n_hits, n_objects)
|
| 189 |
+
L_clusters = torch.tensor(0.0).to(device)
|
| 190 |
+
if frac_combinations != 0:
|
| 191 |
+
L_clusters = L_clusters_calc(
|
| 192 |
+
batch, cluster_space_coords, cluster_index, frac_combinations, q
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# -------
|
| 196 |
+
# Attractive potential term
|
| 197 |
+
# First get all the relevant norms: We only want norms of signal hits
|
| 198 |
+
# w.r.t. the object they belong to, i.e. no noise hits and no noise clusters.
|
| 199 |
+
# First select all norms of all signal hits w.r.t. all objects, mask out later
|
| 200 |
+
|
| 201 |
+
N_k = torch.sum(M, dim=0) # number of hits per object
|
| 202 |
+
norms = torch.sum(
|
| 203 |
+
torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)),
|
| 204 |
+
dim=-1,
|
| 205 |
+
) # take the norm squared
|
| 206 |
+
norms_att = norms[is_sig]
|
| 207 |
+
#att func as in line 159 of object condensation
|
| 208 |
+
|
| 209 |
+
norms_att = torch.log(
|
| 210 |
+
torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att / 2 + 1
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
assert norms_att.size() == (n_hits_sig, n_objects)
|
| 214 |
+
|
| 215 |
+
# Now apply the mask to keep only norms of signal hits w.r.t. to the object
|
| 216 |
+
# they belong to
|
| 217 |
+
norms_att *= M[is_sig]
|
| 218 |
+
|
| 219 |
+
# Sum over hits, then sum per event, then divide by n_hits_per_event, then sum over events
|
| 220 |
+
|
| 221 |
+
V_attractive = (q[is_sig]).unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
|
| 222 |
+
V_attractive = V_attractive.sum(dim=0) # K objects
|
| 223 |
+
V_attractive = V_attractive.view(-1) / (N_k.view(-1) + 1e-3)
|
| 224 |
+
L_V_attractive = torch.mean(V_attractive)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
norms_rep = torch.relu(1. - torch.sqrt(norms + 1e-6))* M_inv
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects)
|
| 231 |
+
V_repulsive = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep
|
| 232 |
+
|
| 233 |
+
# No need to apply a V = max(0, V); by construction V>=0
|
| 234 |
+
assert V_repulsive.size() == (n_hits, n_objects)
|
| 235 |
+
|
| 236 |
+
# Sum over hits, then sum per event, then divide by n_hits_per_event, then sum up events
|
| 237 |
+
nope = n_objects_per_event - 1
|
| 238 |
+
nope[nope == 0] = 1
|
| 239 |
+
|
| 240 |
+
L_V_repulsive = V_repulsive.sum(dim=0)
|
| 241 |
+
number_of_repulsive_terms_per_object = torch.sum(M_inv, dim=0)
|
| 242 |
+
L_V_repulsive = L_V_repulsive.view(
|
| 243 |
+
-1
|
| 244 |
+
) / number_of_repulsive_terms_per_object.view(-1)
|
| 245 |
+
L_V_repulsive = torch.mean(L_V_repulsive)
|
| 246 |
+
L_V_repulsive2 = L_V_repulsive
|
| 247 |
+
|
| 248 |
+
L_V = (
|
| 249 |
+
L_V_attractive
|
| 250 |
+
+ L_V_repulsive
|
| 251 |
+
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
n_noise_hits_per_event = scatter_count(batch[is_noise])
|
| 258 |
+
n_noise_hits_per_event[n_noise_hits_per_event == 0] = 1
|
| 259 |
+
L_beta_noise = (
|
| 260 |
+
s_B
|
| 261 |
+
* (
|
| 262 |
+
(scatter_add(beta[is_noise], batch[is_noise])) / n_noise_hits_per_event
|
| 263 |
+
).sum()
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# L_beta signal term
|
| 267 |
+
|
| 268 |
+
beta_per_object_c = scatter_add(beta[is_sig], object_index)
|
| 269 |
+
beta_alpha = beta[is_sig][index_alpha]
|
| 270 |
+
# hit_type_mask = (g.ndata["hit_type"]==1)*(g.ndata["particle_number"]>0)
|
| 271 |
+
# beta_alpha_track = beta[is_sig*hit_type_mask]
|
| 272 |
+
L_beta_sig = torch.mean(
|
| 273 |
+
1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
L_beta_noise = L_beta_noise / 4
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
L_beta = L_beta_noise + L_beta_sig
|
| 280 |
+
|
| 281 |
+
L_alpha_coordinates = torch.mean(torch.norm(x_alpha_original - x_alpha, p=2, dim=1))
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
L_exp = L_beta
|
| 285 |
+
if (loss_type == "hgcalimplementation") or (loss_type == "vrepweighted") or (loss_type == "baseline"):
|
| 286 |
+
return (
|
| 287 |
+
L_V,
|
| 288 |
+
L_beta,
|
| 289 |
+
L_beta_sig,
|
| 290 |
+
L_beta_noise,
|
| 291 |
+
0,
|
| 292 |
+
0,
|
| 293 |
+
0,
|
| 294 |
+
None,
|
| 295 |
+
None,
|
| 296 |
+
0,
|
| 297 |
+
L_clusters,
|
| 298 |
+
0,
|
| 299 |
+
L_V_attractive,
|
| 300 |
+
L_V_repulsive,
|
| 301 |
+
L_alpha_coordinates,
|
| 302 |
+
L_exp,
|
| 303 |
+
norms_rep,
|
| 304 |
+
norms_att,
|
| 305 |
+
L_V_repulsive2,
|
| 306 |
+
0
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def object_condensation_loss2(
    batch,
    pred,
    pred_2,
    y,
    q_min=0.1,
    use_average_cc_pos=0.0,
    output_dim=4,
    clust_space_norm="none",
):
    """Compute the object-condensation loss for one batched graph.

    :param batch: DGL batched graph whose node data carries ``h`` (input
        coordinates in the first three columns) and ``particle_number``
        (truth hit->cluster index).
    :param pred: per-node model output; columns 0..2 are cluster-space
        coordinates, column 3 is the raw (pre-sigmoid) beta score.
    :param pred_2: per-node energy-correction output, forwarded unchanged
        to :func:`calc_LV_Lbeta`.
    :param y: truth container, forwarded to :func:`calc_LV_Lbeta`.
    :param q_min: minimum charge of the condensation potential.
    :param use_average_cc_pos: weight for averaging condensation-point positions.
    :param output_dim: unused here; kept for interface compatibility.
    :param clust_space_norm: one of ``"twonorm"``, ``"tanh"``, ``"none"``.
    :return: ``(loss, terms)`` where ``loss = terms[0] + terms[1]``
        (L_V + L_beta) and ``terms`` is the full tuple from calc_LV_Lbeta.
    """
    clust_space_dim = 3

    # Column 3 holds the raw beta logits; squash them into (0, 1).
    beta_scores = torch.sigmoid(torch.reshape(pred[:, clust_space_dim], [-1, 1]))
    input_coords = batch.ndata["h"][:, 0:clust_space_dim]

    cluster_coords = pred[:, 0:clust_space_dim]
    if clust_space_norm == "twonorm":
        cluster_coords = torch.nn.functional.normalize(cluster_coords, dim=1)
    elif clust_space_norm == "tanh":
        cluster_coords = torch.tanh(cluster_coords)
    elif clust_space_norm == "none":
        pass
    else:
        raise NotImplementedError

    device = batch.device
    truth_cluster_index = batch.ndata["particle_number"]

    # Per-node event id: 0..n_graphs-1, each repeated by its graph's node count.
    n_graphs = len(batch.batch_num_nodes())
    event_ids = torch.repeat_interleave(
        torch.arange(0, n_graphs).to(device), batch.batch_num_nodes()
    ).to(device)

    terms = calc_LV_Lbeta(
        input_coords,
        batch,
        y,
        0,  # distance threshold
        pred_2,  # energy correction
        beta=beta_scores.view(-1),
        cluster_space_coords=cluster_coords,  # Predicted by model
        cluster_index_per_event=truth_cluster_index.view(-1).long(),  # Truth hit->cluster index
        batch=event_ids.long(),
        qmin=q_min,
        use_average_cc_pos=use_average_cc_pos,
    )

    # Total loss: potential term (L_V) plus the beta term (L_beta).
    loss = 1 * terms[0] + terms[1]
    return loss, terms
|
| 377 |
+
|
| 378 |
+
def formatted_loss_components_string(components: dict) -> str:
    """
    Formats the components returned by calc_LV_Lbeta

    Each entry is rendered with its value and its share of the combined
    L_V + L_beta loss, e.g. "+0.1234 (12.3%)".
    """
    total_loss = components["L_V"] + components["L_beta"]
    # Fraction of the combined loss contributed by every component.
    fractions = {k: v / total_loss for k, v in components.items()}

    def fkey(key):
        return f"{components[key]:+.4f} ({100.*fractions[key]:.1f}%)"

    rendered = {k: fkey(k) for k in components}
    template = (
        "  L_V = {L_V}"
        "\n  L_V_attractive = {L_V_attractive}"
        "\n  L_V_repulsive = {L_V_repulsive}"
        "\n  L_beta = {L_beta}"
        "\n  L_beta_noise = {L_beta_noise}"
        "\n  L_beta_sig = {L_beta_sig}"
    )
    s = template.format(L=total_loss, **rendered)
    if "L_beta_norms_term" in components:
        extra = (
            "\n  L_beta_norms_term = {L_beta_norms_term}"
            "\n  L_beta_logbeta_term = {L_beta_logbeta_term}"
        )
        s += extra.format(**rendered)
    if "L_noise_filter" in components:
        s += f'\n  L_noise_filter = {fkey("L_noise_filter")}'
    return s
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def huber(d, delta):
    """
    See: https://en.wikipedia.org/wiki/Huber_loss#Definition
    Multiplied by 2 w.r.t Wikipedia version (aligning with Jan's definition)
    """
    magnitude = torch.abs(d)
    quadratic_branch = d**2
    linear_branch = 2.0 * delta * (magnitude - delta)
    # Quadratic inside |d| <= delta, linear outside.
    return torch.where(magnitude <= delta, quadratic_branch, linear_branch)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def batch_cluster_indices(
    cluster_id: torch.Tensor, batch: torch.Tensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Turns cluster indices per event to an index in the whole batch
    Example:
    cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
    batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    -->
    offset = torch.LongTensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 5, 5, 5])
    output = torch.LongTensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6])
    """
    device = cluster_id.device
    assert cluster_id.device == batch.device
    # Clusters per event: highest per-event cluster id plus one.
    n_clusters_per_event = scatter_max(cluster_id, batch, dim=-1)[0] + 1
    # Cumulative sum gives each event's first global cluster id; prefix a 0
    # for the first event.
    cumulative = n_clusters_per_event[:-1].cumsum(dim=-1)
    offset_values = torch.cat((torch.zeros(1, device=device), cumulative))
    # Broadcast the per-event offset down to every hit, then shift.
    per_hit_offset = torch.gather(offset_values, 0, batch).long()
    return per_hit_offset + cluster_id, n_clusters_per_event
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=1.0):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.

    Condensation points are hits with beta > tbeta, visited in order of
    decreasing beta; each claims every still-unassigned hit within distance
    td in cluster space. Hits left over at the end are background (-1).
    """
    n_points = betas.size(0)
    passes_threshold = betas > tbeta
    condpoint_indices = passes_threshold.nonzero()
    # Highest-beta condensation points claim their neighbourhood first.
    condpoint_indices = condpoint_indices[(-betas[passes_threshold]).argsort()]

    remaining = torch.arange(n_points)
    clustering = -1 * torch.ones(n_points, dtype=torch.long).to(betas.device)
    for condpoint in condpoint_indices:
        dist = torch.norm(X[remaining] - X[condpoint][0], dim=-1)
        within_range = dist < td
        # Assign the claimed hits, then shrink the unassigned pool
        # (no overwriting of earlier assignments).
        clustering[remaining[within_range]] = condpoint[0]
        remaining = remaining[~within_range]
    return clustering
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def scatter_count(input: torch.Tensor):
    """
    Returns ordered counts over an index array
    Example:
    >>> scatter_count(torch.Tensor([0, 0, 0, 1, 1, 2, 2]))  # input
    >>> [3, 2, 2]
    Index assumptions work like in torch_scatter, so:
    >>> scatter_count(torch.Tensor([1, 1, 1, 2, 2, 4, 4]))
    >>> tensor([0, 3, 2, 0, 2])
    """
    # Summing a 1 per entry, grouped by index value, yields the counts.
    ones = torch.ones_like(input, dtype=torch.long)
    return scatter_add(ones, input.long())
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def scatter_counts_to_indices(input: torch.LongTensor) -> torch.LongTensor:
    """
    Converts counts to indices. This is the inverse operation of scatter_count
    Example:
    input: [3, 2, 2]
    output: [0, 0, 0, 1, 1, 2, 2]
    """
    # Repeat each position index as many times as its count says.
    positions = torch.arange(input.size(0), device=input.device)
    return torch.repeat_interleave(positions, input).long()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def get_inter_event_norms_mask(
    batch: torch.LongTensor, nclusters_per_event: torch.LongTensor
):
    """
    Creates mask of (nhits x nclusters) that is only 1 if hit i is in the same event as cluster j

    Example:
    batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    nclusters_per_event = torch.LongTensor([3, 2, 2])
    returns a 13 x 7 0/1 matrix whose (i, j) entry is 1 iff hit i and
    cluster j belong to the same event.
    """
    device = batch.device
    # One-hot event membership, shape (n_events, n_hits):
    # row e is 1 exactly at hits belonging to event e.
    event_ids = torch.arange(batch.max() + 1, dtype=torch.long, device=device)
    hit_in_event = (batch == event_ids.unsqueeze(-1)).long()
    # Duplicate each event row once per cluster of that event, then
    # transpose to get the (n_hits, n_clusters) layout.
    return hit_in_event.repeat_interleave(nclusters_per_event, dim=0).T
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def isin(ar1, ar2):
    """To be replaced by torch.isin for newer releases of torch"""
    # Broadcast ar1 against ar2 and collapse the comparison axis.
    pairwise_eq = ar1.unsqueeze(-1) == ar2
    return pairwise_eq.any(-1)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def L_clusters_calc(batch, cluster_space_coords, cluster_index, frac_combinations, q):
    """Pairwise hinge-style auxiliary clustering loss.

    For every event, randomly samples "positive" pairs (both hits drawn from
    the same truth cluster's index range) and an equal number of "negative"
    pairs (spanning two clusters), then applies a charge-weighted hinge
    embedding loss: positive pairs are pulled together, negative pairs pushed
    beyond the unit margin.

    :param batch: (n_hits,) event id per hit.
    :param cluster_space_coords: (n_hits, dim) predicted cluster-space coordinates.
    :param cluster_index: (n_hits,) truth cluster id per hit.
    :param frac_combinations: fraction of all same-cluster pairs to sample.
    :param q: (n_hits,) per-hit charge used as pair weight.
    :return: scalar tensor; 0.0 when no event produces any pair.
    """
    number_of_pairs = 0
    # Bug fix: accumulate across ALL events. Previously L_clusters was
    # (re)initialised inside the event loop — discarding every earlier
    # event's contribution — and was never defined at all (UnboundLocalError
    # at the return) when every event had <= 1 cluster.
    L_clusters = torch.tensor(0.0).to(q.device)
    for batch_id in batch.unique():
        bmask = batch == batch_id
        clust_space_filt = cluster_space_coords[bmask]
        pos_pairs_all = []
        neg_pairs_all = []
        # Negative pairs need at least two clusters in the event.
        if len(cluster_index[bmask].unique()) <= 1:
            continue
        for cluster in cluster_index[bmask].unique():
            coords_pos = clust_space_filt[cluster_index[bmask] == cluster]
            coords_neg = clust_space_filt[cluster_index[bmask] != cluster]
            if len(coords_neg) == 0:
                continue
            # Sample frac_combinations of the ~n^2/2 same-cluster pairs.
            total_num = (len(coords_pos) ** 2) / 2
            num = int(frac_combinations * total_num)
            pos_pairs = [
                [
                    np.random.randint(len(coords_pos)),
                    np.random.randint(len(coords_pos)),
                ]
                for _ in range(num)
            ]
            neg_pairs = [
                [
                    np.random.randint(len(coords_pos)),
                    np.random.randint(len(coords_neg)),
                ]
                for _ in range(len(pos_pairs))
            ]
            pos_pairs_all += pos_pairs
            neg_pairs_all += neg_pairs
        pos_pairs = torch.tensor(pos_pairs_all)
        neg_pairs = torch.tensor(neg_pairs_all)
        assert pos_pairs.shape == neg_pairs.shape
        if len(pos_pairs) == 0:
            continue
        cluster_space_coords_filtered = cluster_space_coords[bmask]
        qs_filtered = q[bmask]
        # NOTE(review): the sampled indices are drawn in the range of the
        # per-cluster subsets but index the WHOLE event's hit array here —
        # confirm this is intended (kept as-is to preserve behavior).
        pos_norms = (
            cluster_space_coords_filtered[pos_pairs[:, 0]]
            - cluster_space_coords_filtered[pos_pairs[:, 1]]
        ).norm(dim=-1)
        neg_norms = (
            cluster_space_coords_filtered[neg_pairs[:, 0]]
            - cluster_space_coords_filtered[neg_pairs[:, 1]]
        ).norm(dim=-1)
        q_pos = qs_filtered[pos_pairs[:, 0]]
        q_neg = qs_filtered[neg_pairs[:, 0]]
        q_s = torch.cat([q_pos, q_neg])
        norms_all = torch.cat([pos_norms, neg_norms])
        ys = torch.cat([torch.ones_like(pos_norms), -torch.ones_like(neg_norms)])
        # HingeEmbeddingLoss() defaults to 'mean' reduction, which is what the
        # deprecated reduce=None argument resolved to; the scalar mean is then
        # scaled by each q and summed (i.e. multiplied by sum(q_s)).
        L_clusters += torch.sum(q_s * torch.nn.HingeEmbeddingLoss()(norms_all, ys))
        number_of_pairs += norms_all.shape[0]
    if number_of_pairs > 0:
        L_clusters = L_clusters / number_of_pairs

    return L_clusters
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
|
src/layers/regression/loss_regression.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def obtain_PID_charged(dic, pid_true_matched, pids_charged, args, pid_conversion_dict):
    """Build one-hot truth labels for the charged-PID classification head.

    :param dic: must contain "charged_PID_pred" (logits tensor) and
        "charged_idx" (indices of matched charged showers).
    :param pid_true_matched: true PDG ids of all matched showers.
    :param pids_charged: class ids handled by the charged head.
    :param args: unused; kept for interface compatibility.
    :param pid_conversion_dict: maps a true PDG id to a class id
        (unknown ids fall back to class 3).
    :return: (predictions, one-hot truths, mask) where mask is 0 for truth
        classes not present in pids_charged.
    """
    preds = dic["charged_PID_pred"]
    true_pids = np.array(pid_true_matched)[dic["charged_idx"].cpu().tolist()]
    # One row per matched charged shower, one column per handled class.
    onehot = torch.zeros(len(true_pids), len(pids_charged)).to(preds.device)
    valid_mask = torch.ones(len(true_pids))
    class_ids = np.array(pids_charged)
    for row, pid in enumerate(true_pids):
        if pid not in pid_conversion_dict:
            print("Unknown PID", pid)
        class_id = pid_conversion_dict.get(pid, 3)
        matches = np.where(class_ids == class_id)[0]
        if len(matches) == 0:
            # Class not handled by this head: exclude from the loss.
            valid_mask[row] = 0
        else:
            onehot[row, matches[0]] = 1
    return preds, onehot, valid_mask
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def obtain_PID_neutral(dic, pid_true_matched, pids_neutral, args, pid_conversion_dict):
    """Build one-hot truth labels for the neutral-PID classification head.

    :param dic: must contain "neutral_PID_pred" (logits tensor) and
        "neutrals_idx" (indices of matched neutral showers).
    :param pid_true_matched: true PDG ids of all matched showers.
    :param pids_neutral: class ids handled by the neutral head.
    :param args: unused; kept for interface compatibility.
    :param pid_conversion_dict: maps a true PDG id to a class id
        (unknown ids fall back to class 3).
    :return: (predictions, one-hot truths, mask) where mask is 0 for truth
        classes not present in pids_neutral.
    """
    preds = dic["neutral_PID_pred"]
    neutral_idx = dic["neutrals_idx"]
    true_pids = np.array(pid_true_matched)[neutral_idx.cpu()]
    # A single matched shower indexes down to a scalar; re-wrap it as a list.
    if type(true_pids) == np.float64:
        true_pids = [true_pids]
    # One row per matched neutral shower, one column per handled class.
    onehot = torch.zeros(len(true_pids), len(pids_neutral)).to(preds.device)
    valid_mask = torch.ones(len(true_pids))

    # Convert from true PID to the head's class encoding.
    class_ids = np.array(pids_neutral)
    for row, pid in enumerate(true_pids):
        if pid not in pid_conversion_dict:
            print("Unknown PID", pid)
        class_id = pid_conversion_dict.get(pid, 3)
        matches = np.where(class_ids == class_id)[0]
        if len(matches) == 0:
            # Class not handled by this head: exclude from the loss.
            valid_mask[row] = 0
        else:
            onehot[row, matches[0]] = 1
    onehot = onehot.to(neutral_idx.device)
    return preds, onehot, valid_mask
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
src/layers/shower_dataframe.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataFrame construction and shower-level helpers for particle-flow reconstruction."""
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from torch_scatter import scatter_add, scatter_mean, scatter_max
|
| 5 |
+
|
| 6 |
+
from src.layers.clustering import remove_labels_of_double_showers
|
| 7 |
+
from src.layers.shower_matching import obtain_intersection_values
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------------------------------
|
| 11 |
+
# Small tensor helpers
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
def nan_like(t):
    """NaN-filled tensor shaped like *t* (multiplying zeros by NaN)."""
    return torch.nan * torch.zeros_like(t)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def nan_tensor(*size, device):
    """NaN-filled tensor of the given size, allocated on *device*."""
    zeros = torch.zeros(*size, device=device)
    return zeros * torch.nan
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _window(tensor, start, count):
|
| 23 |
+
return tensor[start : start + count]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _compute_pandora_momentum(labels, g):
    """Scatter-mean the pandora momentum/reference-point node features per cluster.

    Returns (pxyz, ref_pt, pandora_pid, calc_pandora_momentum). All three
    tensor outputs are None when the graph does not carry 'pandora_momentum'.
    """
    if "pandora_momentum" not in g.ndata:
        return None, None, None, False
    momentum = g.ndata["pandora_momentum"]
    ref_point = g.ndata["pandora_reference_point"]
    # Average each Cartesian component over the hits of every cluster label.
    pxyz = torch.stack(
        [scatter_mean(momentum[:, axis], labels) for axis in range(3)], dim=1
    )
    ref_pt = torch.stack(
        [scatter_mean(ref_point[:, axis], labels) for axis in range(3)], dim=1
    )
    pandora_pid = scatter_mean(g.ndata["pandora_pid"], labels)
    return pxyz, ref_pt, pandora_pid, True
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Per-shower correction
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def get_correction_per_shower(labels, dic):
    """Pick one energy-correction factor per predicted shower.

    For every cluster label the correction of the hit with the highest beta
    is used. When label 0 is absent from *labels*, a zero entry is prepended
    so the output stays aligned with scatter-indexed per-label arrays.
    """
    graph = dic["graph"]
    corrections_all = graph.ndata["correction"]
    betas_all = graph.ndata["beta"]
    per_label = []
    for position, label in enumerate(torch.unique(labels)):
        if position == 0 and label != 0:
            # No label-0 cluster present: pad slot 0 with a zero correction.
            per_label.append(corrections_all[0].view(-1) * 0)
        in_label = labels == label
        # Representative hit = highest-beta hit of this cluster.
        best_hit = torch.argmax(betas_all[in_label])
        per_label.append(corrections_all[in_label][best_hit].view(-1))
    return torch.cat(per_label, dim=0)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Track–cluster distance helpers
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
def distance_to_true_cluster_of_track(dic, labels):
    """For every predicted cluster containing a track, measure how far that
    track sits from the mean position of its TRUE particle's hits.

    :param dic: dict with key "graph" (graph whose ndata carries
        "hit_type", "pos_hits_xyz", "particle_number").
    :param labels: per-hit predicted cluster label.
    :return: (distances, number_of_tracks), both indexed by cluster label.
        distances is on labels.device; number_of_tracks stays on CPU.
    """
    g = dic["graph"]
    # hit_type == 1 selects track hits (inferred from usage below — confirm).
    mask_hit_type_t2 = g.ndata["hit_type"] == 1
    # If label 0 is absent, allocate one extra slot so indexing by label
    # still works for the highest label value.
    if torch.sum(labels.unique() == 0) == 0:
        distances = torch.zeros(len(labels.unique()) + 1).float().to(labels.device)
        number_of_tracks = torch.zeros(len(labels.unique()) + 1).int()
    else:
        distances = torch.zeros(len(labels.unique())).float().to(labels.device)
        number_of_tracks = torch.zeros(len(labels.unique())).int()
    for i, label in enumerate(labels.unique()):
        mask_labels_i = labels == label
        # Track hits inside this predicted cluster.
        mask = mask_labels_i * mask_hit_type_t2
        if mask.sum() == 0:
            continue
        # Take the first track of the cluster as the reference position.
        pos_track = g.ndata["pos_hits_xyz"][mask][0]
        if pos_track.shape[0] == 0:
            continue
        # True particle the track belongs to, and all hits of that particle.
        true_part_idx_track = g.ndata["particle_number"][mask_labels_i * mask_hit_type_t2][0].int()
        mask_labels_i_true = g.ndata["particle_number"] == true_part_idx_track
        mean_pos_cluster_true = torch.mean(
            g.ndata["pos_hits_xyz"][mask_labels_i_true], dim=0
        )
        # How many track hits the true particle has in total.
        number_of_tracks[label] = torch.sum(mask_labels_i_true * mask_hit_type_t2)
        # 3300 is presumably a detector length scale used for normalisation — confirm.
        distances[label] = torch.norm(mean_pos_cluster_true - pos_track) / 3300
    return distances, number_of_tracks
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def distance_to_cluster_track(dic, is_track_in_MC):
    """Per-particle distance between each track and the mean position of its
    particle's calorimeter-type hits.

    :param dic: dict with key "graph" (graph whose ndata carries
        "hit_type", "pos_hits_xyz", "particle_number").
    :param is_track_in_MC: per-particle tensor used as the output template;
        entries for particles with a track get overwritten by the distance.
    :return: float tensor shaped like is_track_in_MC.
    """
    g = dic["graph"]
    # hit_type == 2 / == 1 appear to tag calo hits vs track hits respectively
    # (inferred from usage — confirm against the dataset encoding).
    mask_hit_type_t1 = g.ndata["hit_type"] == 2
    mask_hit_type_t2 = g.ndata["hit_type"] == 1
    pos_track = g.ndata["pos_hits_xyz"][mask_hit_type_t2]
    particle_track = g.ndata["particle_number"][mask_hit_type_t2]
    if len(particle_track) > 0:
        mean_pos_cluster_all = []
        for i in particle_track:
            if i == 0:
                # Noise track: use the origin as a placeholder cluster position.
                mean_pos_cluster_all.append(torch.zeros((1, 3)).view(-1, 3).to(particle_track.device))
            else:
                # Mean position of this particle's hit_type==2 hits.
                mask_labels_i = g.ndata["particle_number"] == i
                mean_pos_cluster = torch.mean(g.ndata["pos_hits_xyz"][mask_labels_i * mask_hit_type_t1], dim=0)
                mean_pos_cluster_all.append(mean_pos_cluster.view(-1, 3))
        mean_pos_cluster_all = torch.cat(mean_pos_cluster_all, dim=0)
        # 1000 is presumably a length-scale normalisation — confirm units.
        distance_track_cluster = torch.norm(mean_pos_cluster_all - pos_track, dim=1) / 1000
        if len(particle_track) > len(torch.unique(particle_track)):
            # Several tracks share a particle: keep the closest track per particle.
            distance_track_cluster_unique = []
            for i in torch.unique(particle_track):
                mask_tracks = particle_track == i
                distance_track_cluster_unique.append(torch.min(distance_track_cluster[mask_tracks]).view(-1))
            distance_track_cluster_unique = torch.cat(distance_track_cluster_unique, dim=0)
            unique_particle_track = torch.unique(particle_track)
        else:
            distance_track_cluster_unique = distance_track_cluster
            unique_particle_track = particle_track
        # Scatter the per-particle distances into a copy of the template.
        distance_to_cluster_all = is_track_in_MC.clone().float()
        distance_to_cluster_all[unique_particle_track.long()] = distance_track_cluster_unique
        return distance_to_cluster_all
    else:
        # No tracks in the event: return the template unchanged (as float).
        return is_track_in_MC.clone().float()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
# Main DataFrame builder
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None,
    extra_features=None,
    labels_clusters_removed_tracks=None,
):
    """Build a per-event pandas DataFrame comparing true particles with the
    predicted (or Pandora) showers they were matched to.

    One row per true particle comes first (unmatched truths keep NaN
    prediction columns); fake showers (predicted clusters matched to no
    truth) are appended below with NaN truth columns.

    :param labels: per-hit predicted cluster index (0 = noise / unclustered).
    :param dic: event dict with keys "graph" (graph of hits with ndata fields
        such as e_hits, hit_type, particle_number) and "part_true" (true
        particle container with E_corrected, gen_status, vertex, coord, pid).
    :param shower_p_unique: unique predicted cluster ids.
        NOTE(review): entries of this tensor are overwritten with -1 below,
        so the caller's tensor is mutated in place -- confirm intended.
    :param particle_ids: unique true particle ids.
    :param row_ind, col_ind: truth / prediction indices of matched pairs.
    :param i_m_w: energy-weighted intersection matrix from the matching step.
    :param pandora: if True, evaluate Pandora PFOs instead of the model.
    :param e_corr: optional flat tensor of calibrated shower energies,
        matched showers first and ``number_of_fakes`` fake showers at the end.
    :param number_of_showers_total, number_of_fake_showers_total: running
        offsets into the flat ``e_corr``-style arrays across events in a batch.
    :param ec_x: unused in this implementation.
    :return: the DataFrame alone when ``number_of_showers_total`` is None,
        otherwise (df, number_of_showers_total, number_of_fake_showers_total);
        ([], 0, 0) when the event has no matched pair at all.
    """
    # --- per-predicted-shower sums over hits --------------------------------
    e_pred_showers = scatter_add(dic["graph"].ndata["e_hits"].view(-1), labels)
    # ECAL / HCAL hit multiplicities per predicted shower (hit_type 2 / 3).
    e_pred_showers_ecal = scatter_add(1 * (dic["graph"].ndata["hit_type"].view(-1) == 2), labels)
    e_pred_showers_hcal = scatter_add(1 * (dic["graph"].ndata["hit_type"].view(-1) == 3), labels)
    if not pandora:
        removed_tracks = scatter_add(1 * labels_clusters_removed_tracks, labels)
    if pandora:
        # Pandora PFO energy is stored per hit, so the mean over a cluster's
        # hits recovers the PFO energy.
        e_pred_showers_cali = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        e_pred_showers_pfo = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        pxyz_pred_pfo, ref_pt_pred_pfo, pandora_pid, calc_pandora_momentum = \
            _compute_pandora_momentum(labels, dic["graph"])
    else:
        if e_corr is None:
            corrections_per_shower = get_correction_per_shower(labels, dic)
            e_pred_showers_cali = e_pred_showers * corrections_per_shower
        else:
            corrections_per_shower = e_corr.view(-1)
            if number_of_fakes > 0:
                # e_corr carries matched showers first, fakes at the tail.
                corrections_per_shower_fakes = corrections_per_shower[-number_of_fakes:]
                corrections_per_shower = corrections_per_shower[:-number_of_fakes]

    # --- per-true-particle sums over hits ------------------------------------
    e_reco_showers = scatter_add(
        dic["graph"].ndata["e_hits"].view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    e_label_showers = scatter_max(
        labels.view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )[0]
    # Number of track hits (hit_type == 1) belonging to each true particle.
    is_track_in_MC = scatter_add(
        1 * (dic["graph"].ndata["hit_type"].view(-1) == 1),
        dic["graph"].ndata["particle_number"].long(),
    )
    track_chi = scatter_add(
        1 * (dic["graph"].ndata["chi_squared_tracks"].view(-1) == 1),
        dic["graph"].ndata["particle_number"].long(),
    )
    distance_to_cluster_all = distance_to_cluster_track(dic, is_track_in_MC)
    distances, number_of_tracks = distance_to_true_cluster_of_track(dic, labels)

    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()

    # When a noise particle id 0 exists, the matching indices were shifted by
    # one in the matching step; undo the shift for truth-side indexing.
    if torch.sum(particle_ids == 0) > 0:
        row_ind_ = row_ind - 1
    else:
        row_ind_ = row_ind

    pred_showers = shower_p_unique
    energy_t = (
        dic["part_true"].E_corrected.view(-1).to(e_pred_showers.device)
    ).float()
    gen_status = (
        dic["part_true"].gen_status.view(-1).to(e_pred_showers.device)
    ).float()
    vertex = dic["part_true"].vertex.to(e_pred_showers.device)
    pos_t = dic["part_true"].coord.to(e_pred_showers.device)
    pid_t = dic["part_true"].pid.to(e_pred_showers.device)
    if not pandora:
        labels = remove_labels_of_double_showers(labels, dic["graph"])
    is_track_per_shower = scatter_add(1 * (dic["graph"].ndata["hit_type"] == 1), labels).int()
    is_track = torch.zeros(energy_t.shape).to(e_pred_showers.device)

    # Predicted-shower index of each match (+1 skips the noise cluster 0).
    index_matches = col_ind + 1
    index_matches = index_matches.to(e_pred_showers.device).long()

    # NaN-initialised per-true-particle prediction columns; only the matched
    # rows (row_ind_) get filled below.
    dev = e_pred_showers.device
    matched_es = nan_like(energy_t)
    matched_ECAL = nan_like(energy_t)
    matched_HCAL = nan_like(energy_t)
    matched_positions = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_ref_pt = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_pid = nan_like(energy_t).long()
    matched_positions_pfo = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_pandora_pid = nan_tensor(energy_t.shape[0], device=dev)
    matched_ref_pts_pfo = nan_tensor(energy_t.shape[0], 3, device=dev)
    matched_extra_features = torch.zeros((energy_t.shape[0], 7)) * torch.nan

    matched_es[row_ind_] = e_pred_showers[index_matches]
    matched_ECAL[row_ind_] = 1.0 * e_pred_showers_ecal[index_matches]
    matched_HCAL[row_ind_] = 1.0 * e_pred_showers_hcal[index_matches]

    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        if calc_pandora_momentum:
            matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
            matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
        is_track[row_ind_] = is_track_per_shower[index_matches].float()
    else:
        if e_corr is None:
            matched_es_cali = matched_es.clone()
            matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
            cluster_removed_tracks = matched_es.clone()
        else:
            matched_es_cali = matched_es.clone()
            number_of_showers = e_pred_showers[index_matches].shape[0]
            # _window slices this event's chunk out of the flat batch array,
            # starting at the running offset number_of_showers_total.
            matched_es_cali[row_ind_] = _window(
                corrections_per_shower, number_of_showers_total, number_of_showers
            )
            cluster_removed_tracks = matched_es.clone()
            cluster_removed_tracks[row_ind_] = 1.0 * removed_tracks[index_matches]

            if pred_pos is not None:
                matched_positions[row_ind_] = _window(pred_pos, number_of_showers_total, number_of_showers)
                matched_ref_pt[row_ind_] = _window(pred_ref_pt, number_of_showers_total, number_of_showers)
                matched_pid[row_ind_] = _window(pred_pid, number_of_showers_total, number_of_showers)
                if not pandora:
                    matched_extra_features[row_ind_] = torch.tensor(
                        _window(extra_features, number_of_showers_total, number_of_showers)
                    )

            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = _window(
                corrections_per_shower, number_of_showers_total, number_of_showers
            )
            number_of_showers_total = number_of_showers_total + number_of_showers
        is_track[row_ind_] = is_track_per_shower[index_matches].float()

    # match the tracks to the particle
    # particle_number_u maps noise (0) to 100 so scatter_max over track hits
    # yields the track-owning particle id per predicted cluster.
    dic["graph"].ndata["particle_number_u"] = dic["graph"].ndata["particle_number"].clone()
    dic["graph"].ndata["particle_number_u"][dic["graph"].ndata["particle_number_u"] == 0] = 100
    tracks_label = scatter_max(
        (dic["graph"].ndata["hit_type"] == 1) * (dic["graph"].ndata["particle_number_u"]), labels
    )[0].int()
    tracks_label = tracks_label - 1
    tracks_label[tracks_label < 0] = 0
    matched_es_tracks = nan_like(energy_t)
    matched_es_tracks_1 = nan_like(energy_t)
    matched_es_tracks[row_ind_] = row_ind_.float()
    matched_es_tracks_1[row_ind_] = tracks_label[index_matches].float()
    # 1 where the track inside the matched cluster belongs to the matched
    # true particle, masked to clusters that actually contain a track.
    matched_es_tracks_1 = 1.0 * (matched_es_tracks == matched_es_tracks_1)
    matched_es_tracks_1 = matched_es_tracks_1 * is_track

    intersection_E = nan_like(energy_t)
    if len(col_ind) > 0:
        ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
        intersection_E[row_ind_] = ie_e.to(e_pred_showers.device)
        # Mark matched showers (and the noise cluster 0) so the remaining
        # clusters are the fakes.
        pred_showers[index_matches] = -1
        pred_showers[0] = -1
        mask = pred_showers != -1
        fakes_in_event = mask.sum()
        fake_showers_e = e_pred_showers[mask]
        fake_showers_e_hcal = e_pred_showers_hcal[mask]
        fake_showers_e_ecal = e_pred_showers_ecal[mask]
        number_of_fake_showers = mask.sum()

        all_labels = labels.unique().to(e_pred_showers.device)
        number_of_fake_showers = mask.sum()
        fakes_labels = torch.where(mask)[0].to(e_pred_showers.device)
        fake_showers_distance_to_cluster = distances[fakes_labels.cpu()]
        fake_showers_num_tracks = number_of_tracks[fakes_labels.cpu()]

        if e_corr is None or pandora:
            fake_showers_e_cali = e_pred_showers_cali[mask]
        else:
            # Fake-shower quantities live in the tail of the flat batch
            # arrays; slice this event's chunk by the running fake offset.
            fakes_positions = pred_pos[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_cali = e_corr[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fakes_pid_pred = pred_pid[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_reco = e_reco_showers[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fakes_positions = fakes_positions.to(e_pred_showers.device)
            fakes_extra_features = extra_features[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total + number_of_fake_showers]
            fake_showers_e_cali = fake_showers_e_cali.to(e_pred_showers.device)
            fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            fake_showers_e_reco = fake_showers_e_reco.to(e_pred_showers.device)

        if pandora:
            fake_pandora_pid = (torch.zeros((fake_showers_e.shape[0], 3)) * torch.nan).to(dev)
            fake_pandora_pid = pandora_pid[mask]
            if calc_pandora_momentum:
                fake_positions_pfo = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
                fake_positions_pfo = pxyz_pred_pfo[mask]
                fakes_positions_ref = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
                fakes_positions_ref = ref_pt_pred_pfo[mask]
        if not pandora:
            if e_corr is None:
                fake_showers_e_cali_factor = corrections_per_shower[mask]
            else:
                fake_showers_e_cali_factor = fake_showers_e_cali
        # NaN truth columns for the appended fake-shower rows.
        fake_showers_showers_e_truw = nan_tensor(fake_showers_e.shape[0], device=dev)
        fake_showers_vertex = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
        fakes_is_track = (torch.zeros((fake_showers_e.shape[0])) * torch.nan).to(dev)
        fakes_is_track = is_track_per_shower[mask]
        fakes_positions_t = nan_tensor(fake_showers_e.shape[0], 3, device=dev)
        if not pandora:
            number_of_fake_showers_total = number_of_fake_showers_total + number_of_fake_showers

        # Append the fake-shower rows below the true-particle rows.
        energy_t = torch.cat((energy_t, fake_showers_showers_e_truw), dim=0)
        gen_status = torch.cat((gen_status, fake_showers_showers_e_truw), dim=0)
        vertex = torch.cat((vertex, fake_showers_vertex), dim=0)
        pid_t = torch.cat((pid_t.view(-1), fake_showers_showers_e_truw), dim=0)
        pos_t = torch.cat((pos_t, fakes_positions_t), dim=0)
        # [1:] drops the noise entry (particle id 0) from truth-side arrays.
        e_reco = torch.cat((e_reco_showers[1:], fake_showers_showers_e_truw), dim=0)
        e_labels = torch.cat((e_label_showers[1:], 0 * fake_showers_showers_e_truw), dim=0)
        is_track_in_MC = torch.cat((is_track_in_MC[1:], fake_showers_num_tracks.to(e_reco.device)), dim=0)
        track_chi = torch.cat((track_chi[1:], fake_showers_num_tracks.to(e_reco.device)), dim=0)
        distance_to_cluster_MC = torch.cat(
            (distance_to_cluster_all[1:], fake_showers_distance_to_cluster.to(e_reco.device)), dim=0
        )
        e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
        e_pred_ECAL = torch.cat((matched_ECAL, fake_showers_e_ecal), dim=0)
        e_pred_HCAL = torch.cat((matched_HCAL, fake_showers_e_hcal), dim=0)
        e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0)
        if pred_pos is not None:
            e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0)
            e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0)
            e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_positions), dim=0)
            extra_features_all = torch.cat(
                (matched_extra_features, torch.tensor(fakes_extra_features)), dim=0
            )
        if pandora:
            e_pred_cali_pfo = torch.cat((matched_es_cali_pfo, fake_showers_e_cali), dim=0)
            positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0)
            pandora_pid = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0)
            ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0)
        else:
            cluster_removed_tracks = torch.cat((cluster_removed_tracks, 0 * fake_showers_e_cali), dim=0)
        if not pandora:
            calibration_factor = torch.cat((calibration_per_shower, fake_showers_e_cali_factor), dim=0)

        e_pred_t = torch.cat(
            (intersection_E, nan_like(fake_showers_e)),
            dim=0,
        )
        is_track = torch.cat((is_track, fakes_is_track.to(is_track.device)), dim=0)
        matched_es_tracks_1 = torch.cat(
            (matched_es_tracks_1, 0 * fakes_is_track.to(is_track.device)), dim=0
        )

        # Build shared base dict, then update with pandora- or non-pandora-specific keys
        d = {
            "true_showers_E": energy_t.detach().cpu(),
            "reco_showers_E": e_reco.detach().cpu(),
            "pred_showers_E": e_pred.detach().cpu(),
            "e_pred_and_truth": e_pred_t.detach().cpu(),
            "pid": pid_t.detach().cpu(),
            "step": torch.ones_like(energy_t.detach().cpu()) * step,
            "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
            "is_track_in_cluster": is_track.detach().cpu(),
            "is_track_correct": matched_es_tracks_1.detach().cpu(),
            "is_track_in_MC": is_track_in_MC.detach().cpu(),
            "track_chi": track_chi.detach().cpu(),
            "distance_to_cluster_MC": distance_to_cluster_MC.detach().cpu(),
            "vertex": vertex.detach().cpu().tolist(),
            "ECAL_hits": e_pred_ECAL.detach().cpu(),
            "HCAL_hits": e_pred_HCAL.detach().cpu(),
            "gen_status": gen_status.detach().cpu(),
            "labels": e_labels.detach().cpu(),
        }
        if pandora:
            d.update({
                "pandora_calibrated_E": e_pred_cali.detach().cpu(),
                "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
                "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
                "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
                "pandora_pid": pandora_pid.detach().cpu(),
            })
        else:
            d.update({
                "calibration_factor": calibration_factor.detach().cpu(),
                "calibrated_E": e_pred_cali.detach().cpu(),
                "cluster_removed_tracks": cluster_removed_tracks.detach().cpu(),
            })
            if pred_pos is not None:
                d["pred_pos_matched"] = e_pred_pos.detach().cpu().tolist()
                d["pred_pid_matched"] = e_pred_pid.detach().cpu().tolist()
                d["pred_ref_pt_matched"] = e_pred_ref_pt.detach().cpu().tolist()
                d["matched_extra_features"] = extra_features_all.detach().cpu().tolist()

        d["true_pos"] = pos_t.detach().cpu().tolist()
        df = pd.DataFrame(data=d)
        if number_of_showers_total is None:
            return df
        else:
            return df, number_of_showers_total, number_of_fake_showers_total
    else:
        return [], 0, 0
|
src/layers/shower_matching.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shower matching utilities for particle-flow reconstruction."""
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch_scatter import scatter_add
|
| 5 |
+
from scipy.optimize import linear_sum_assignment
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CachedIndexList:
    """Wrapper around a list whose ``index`` lookups are memoized.

    ``list.index`` is O(n); repeated lookups of the same value hit the
    cache instead.
    """

    def __init__(self, lst):
        # Wrapped list and the value -> first-position memo.
        self.lst = lst
        self.cache = {}

    def index(self, value):
        """Return the first position of ``value`` in the list, caching it."""
        try:
            return self.cache[value]
        except KeyError:
            pos = self.lst.index(value)
            self.cache[value] = pos
            return pos
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_labels_pandora(dic, device):
    """Compact per-hit Pandora PFO ids into contiguous labels 0..K-1.

    The raw ids are shifted by +1, and each hit is relabelled with the
    rank of its id among the sorted unique ids of the event.
    """
    raw = dic["graph"].ndata["pandora_pfo"].long() + 1
    raw_np = raw.detach().cpu().numpy()
    # np.unique returns the sorted unique ids; their positions define the
    # compact label of each id (same ordering list.index would give).
    position_of = {pfo_id: rank for rank, pfo_id in enumerate(np.unique(raw_np))}
    compact = [position_of[v] for v in raw_np]
    return torch.Tensor(compact).long().to(device)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Overlap between predicted clusters (rows) and true particles (columns).

    Returns two (n_pred, n_true) matrices: shared-hit counts and the
    summed hit energy of the shared hits.
    """
    n_pred = len(shower_p_unique)
    n_true = len(particle_ids)
    dev = shower_p_unique.device
    intersection_matrix = torch.zeros((n_pred, n_true)).to(dev)
    intersection_matrix_w = torch.zeros((n_pred, n_true)).to(dev)
    truth_per_hit = dic["graph"].ndata["particle_number"]
    for col, particle_id in enumerate(particle_ids):
        belongs = truth_per_hit == particle_id
        # 1 for hits of this particle, 0 elsewhere -> counts per cluster.
        hit_flags = torch.zeros_like(labels)
        hit_flags[belongs] = 1
        # Energy of this particle's hits only -> energy overlap per cluster.
        masked_energy = e_hits.clone()
        masked_energy[~belongs] = 0
        intersection_matrix[:, col] = scatter_add(hit_flags, labels)
        intersection_matrix_w[:, col] = scatter_add(masked_energy, labels.to(masked_energy.device))
    return intersection_matrix, intersection_matrix_w
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Union sizes between predicted clusters (rows) and true particles (columns).

    Entry (i, j) is the number of hits that belong to predicted cluster
    ``shower_p_unique[i]`` OR to true particle ``particle_ids[j]``.

    :param shower_p_unique: unique predicted cluster ids.
    :param particle_ids: unique true particle ids.
    :param labels: per-hit predicted cluster id.
    :param dic: event dict; dic["graph"].ndata["particle_number"] holds the
        per-hit true particle id.
    :return: float tensor of shape (len(shower_p_unique), len(particle_ids)).
    """
    len_pred_showers = len(shower_p_unique)
    union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
    # Fix: dropped the unused `counts` buffer that was allocated on every
    # iteration; renamed the loop variables so the `id` builtin is not shadowed.
    for index, particle_id in enumerate(particle_ids):
        mask_p = dic["graph"].ndata["particle_number"] == particle_id
        for index_pred, pred_id in enumerate(shower_p_unique):
            mask_pred_p = labels == pred_id
            # `+` on bool tensors is logical OR, so this counts the union.
            mask_union = mask_pred_p + mask_p
            union_matrix[index_pred, index] = torch.sum(mask_union)
    return union_matrix
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    """Extract the energy overlap of each matched (truth, prediction) pair.

    The matrix is trimmed of its noise row (and noise column when a
    particle id 0 exists) and transposed to (true, pred) orientation;
    the matched entries are then gathered. Returns 0 when no pair is given.
    """
    picked = []
    particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(particle_ids == 0) > 0:
        # A noise particle exists: drop row 0 and column 0, and shift the
        # truth indices accordingly.
        transposed = torch.transpose(intersection_matrix_w[1:, 1:], 1, 0)
        row_ind = row_ind - 1
    else:
        transposed = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
    for r, c in zip(row_ind, col_ind):
        picked.append(transposed[r, c].view(-1))
    if picked:
        return torch.cat(picked, dim=0)
    return 0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    hdbscan=False,
):
    """Match predicted clusters to true particles by IoU + Hungarian assignment.

    :param labels: per-hit predicted cluster id (0 = noise).
    :param dic: event dict; dic["graph"].ndata supplies e_hits and
        particle_number per hit.
    :param particle_ids: unique true particle ids.
    :param model_output: only its ``device`` is used here, to place the
        intermediate matrices.
    :param local_rank, i, path_save, pandora, hdbscan: unused in this
        implementation (kept for a compatible call signature).
    :return: (shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix) where
        row_ind/col_ind index the matched true particles / predicted showers
        and i_m_w is the energy-weighted intersection matrix.
    """
    # Pairs with IoU below this threshold are never matched.
    iou_threshold = 0.25
    shower_p_unique = torch.unique(labels)
    # Guarantee a leading noise cluster 0 so downstream "+1" indexing holds.
    if torch.sum(labels == 0) == 0:
        shower_p_unique = torch.cat(
            (
                torch.Tensor([0]).to(shower_p_unique.device).view(-1),
                shower_p_unique.view(-1),
            ),
            dim=0,
        )
    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    # Intersection-over-union of every (pred cluster, true particle) pair.
    iou_matrix = i_m / u_m
    # Drop the noise row (and noise column when particle id 0 exists) and
    # transpose to (true, pred) orientation for the assignment.
    if torch.sum(particle_ids == 0) > 0:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, 1:], 1, 0).clone().detach().cpu().numpy()
        )
    else:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
        )
    iou_matrix_num[iou_matrix_num < iou_threshold] = 0
    # Hungarian algorithm maximises total IoU (minimises its negation).
    row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # Keep only assignments with a non-zero (i.e. above-threshold) IoU.
    mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0
    row_ind = row_ind[mask_matching_matrix]
    col_ind = col_ind[mask_matching_matrix]
    # Undo the column-0 trim on the truth side so indices refer to
    # the original particle_ids ordering.
    if torch.sum(particle_ids == 0) > 0:
        row_ind = row_ind + 1
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix
|
src/layers/tools_for_regression.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch_scatter import scatter_mean, scatter_sum
|
| 4 |
+
|
| 5 |
+
def pick_lowest_chi_squared(pxpypz, chi_s, batch_idx, xyz_nodes):
    """Pick, per graph in the batch, the track with the lowest chi-squared.

    For graphs with a single track the track is taken as-is.
    Returns the stacked momentum directions (N, 3) and track positions (N, 3).
    """
    directions = []
    positions = []
    for b in torch.unique(batch_idx):
        in_batch = batch_idx == b
        if torch.sum(in_batch) > 1:
            # Several candidate tracks: keep the best-fitting one.
            best = torch.argmin(chi_s[in_batch])
            directions.append(pxpypz[in_batch][best].view(-1, 3))
            positions.append(xyz_nodes[in_batch][best].view(-1, 3))
        else:
            directions.append(pxpypz[in_batch].view(-1, 3))
            positions.append(xyz_nodes[in_batch].view(-1, 3))
    return torch.concat(directions, dim=0), torch.stack(positions)[:, 0]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AverageHitsP(torch.nn.Module):
    # Same layout of the module as the GNN one, but just computes the average
    # of the hits. Try to compare this + ML clustering with Pandora.
    def __init__(self, ecal_only=False):
        super(AverageHitsP, self).__init__()
        # When True, graphs with a significant ECAL fraction have their
        # HCAL hits zeroed before the energy-weighted average.
        self.ecal_only = ecal_only

    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons

        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new: batched graphs; ndata["h"] columns 0-2 appear to be
            hit xyz, 5/6 ECAL/HCAL flags, 8 the hit energy -- TODO confirm
        :return: (|p| per graph, unit direction per graph, reference point per graph)
        """
        assert graphs_new is not None
        batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
        batch_idx = []
        batch_bounds = []
        if self.ecal_only:
            mask_ecal_only = []  # whether to consider only ECAL or ECAL+HCAL
        # Expand the per-graph node counts into a per-node graph index.
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
            batch_bounds.append(n)
        batch_idx = np.array(batch_idx)
        for i in range(len(np.unique(batch_idx))):
            if self.ecal_only:
                # Per-graph ECAL hit fraction, replicated to every hit of
                # that graph.
                n_ecal_hits = (graphs_new.ndata["h"][batch_idx == i, 5] > 0).sum()
                n_hcal_hits = (graphs_new.ndata["h"][batch_idx == i, 6] > 0).sum()
                for _ in range(batch_num_nodes[i]):
                    mask_ecal_only.append((n_ecal_hits / (n_hcal_hits + n_ecal_hits)).item())
        batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
        if self.ecal_only:
            mask_ecal_only = torch.tensor(mask_ecal_only)  # round().int().bool().to(graphs_new.device)
            # Graphs with more than 5% ECAL hits get their HCAL hits dropped.
            mask_ecal_only = (mask_ecal_only > 0.05).int().bool().to(graphs_new.device)
            # mask_ecal_only = torch.zeros(len(mask_ecal_only)).bool().to(graphs_new.device)
        xyz_hits = graphs_new.ndata["h"][:, :3]
        E_hits = graphs_new.ndata["h"][:, 8]
        if self.ecal_only:
            hcal_hits = graphs_new.ndata["h"][:, 6] > 0
            # NOTE(review): E_hits is a view of graphs_new.ndata["h"], so this
            # zeroing mutates the input graph in place -- confirm intended.
            E_hits[mask_ecal_only & (hcal_hits)] = 0
        # Energy-weighted barycenter of the hits of each graph.
        weighted_avg_hits = scatter_sum(xyz_hits * E_hits.unsqueeze(1), batch_idx, dim=0)
        E_total = scatter_sum(E_hits, batch_idx, dim=0)
        p_direction = weighted_avg_hits / E_total.unsqueeze(1)
        p_tracks = torch.norm(p_direction, dim=1)
        p_direction = p_direction / torch.norm(p_direction, dim=1).unsqueeze(1)
        # if self.pos_regression:
        return p_tracks, p_direction, weighted_avg_hits / E_total.unsqueeze(1) * 3300  # Reference point
        # return p_tracks
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class PickPAtDCA(torch.nn.Module):
    # Same layout of the module as the GNN one, but just picks the track
    def __init__(self):
        super(PickPAtDCA, self).__init__()

    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons

        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new: batched graphs; ndata["h"] columns 3-6 appear to
            one-hot encode the hit type (1 = track) -- TODO confirm
        :return: (|p| per track, momentum direction per graph,
            barycenter-to-track displacement per graph)
        """
        assert graphs_new is not None
        batch_num_nodes = graphs_new.batch_num_nodes()
        batch_idx = []
        batch_bounds = []
        # Expand the per-graph node counts into a per-node graph index.
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
            batch_bounds.append(n)
        batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Hit type is one-hot encoded in columns 3..6 of "h".
        ht = graphs_new.ndata["h"][:, 3:7].argmax(dim=1)
        filt = ht == 1  # track
        filt_hits = ((ht == 2) + (ht == 3)).bool()  # calorimeter hits (unused here)

        # One momentum per graph: the track with the lowest chi-squared.
        p_direction, p_xyz = pick_lowest_chi_squared(
            graphs_new.ndata["pos_pxpypz_at_vertex"][filt],
            graphs_new.ndata["chi_squared_tracks"][filt],
            batch_idx[filt],
            graphs_new.ndata["h"][filt, :3]
        )
        # Barycenters of clusters of hits
        xyz_hits = graphs_new.ndata["h"][:, :3]
        E_hits = graphs_new.ndata["h"][:, 8]
        weighted_avg_hits = scatter_sum(xyz_hits * E_hits.unsqueeze(1), batch_idx, dim=0)
        E_total = scatter_sum(E_hits, batch_idx, dim=0)
        barycenters = weighted_avg_hits / E_total.unsqueeze(1)
        p_tracks = torch.norm(p_direction, dim=1)
        return p_tracks, p_direction, barycenters - p_xyz
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ECNetWrapperAvg(torch.nn.Module):
    # use the GNN+NN model for energy correction
    # This one concatenates GNN features to the global features
    def __init__(self):
        super().__init__()
        # Direction estimator: energy-weighted hit barycenter, ECAL-biased.
        self.AvgHits = AverageHitsP(ecal_only=True)

    def predict(self, x_global_features, graphs_new=None, explain=False):
        """
        Forward, named 'predict' for compatibility reasons

        :param x_global_features: Global features of the graphs - to be concatenated to each node feature
        :param graphs_new: batched graphs holding the per-hit features
        :return: (None, unit direction per graph, None, None)
        """
        _, direction, _ = self.AvgHits.predict(x_global_features, graphs_new)
        unit_direction = (direction / torch.norm(direction, dim=1).unsqueeze(1)).clone()
        return None, unit_direction, None, None
|
src/layers/utils_training.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from lightning.pytorch.callbacks import BaseFinetuning
|
| 3 |
+
import torch
|
| 4 |
+
import dgl
|
| 5 |
+
from src.layers.inference_oc import DPC_custom_CLD
|
| 6 |
+
from src.layers.inference_oc import match_showers
|
| 7 |
+
from src.layers.inference_oc import remove_bad_tracks_from_cluster
|
| 8 |
+
class FreezeClustering(BaseFinetuning):
    """Fine-tuning callback that permanently freezes the clustering part of
    the model so only the remaining modules are trained."""

    def __init__(
        self,
    ):
        super().__init__()

    def freeze_before_training(self, pl_module):
        # Freeze the input batch norm, the GATr backbone and the two
        # clustering output heads before any training step runs.
        self.freeze(pl_module.ScaledGooeyBatchNorm2_1)
        self.freeze(pl_module.gatr)
        self.freeze(pl_module.clustering)
        self.freeze(pl_module.beta)

        print("CLUSTERING HAS BEEN FROOOZEN")

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # Intentionally a no-op: the frozen modules are never unfrozen.
        print("Not finetunning")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def obtain_batch_numbers(x, g):
    """Return a per-node tensor mapping each node of the batched DGL graph
    *g* to the index of the sub-graph it belongs to.

    Args:
        x: any tensor on the target device (only ``x.device`` is used).
        g: batched DGL graph.

    Returns:
        1-D float tensor of length ``g.number_of_nodes()`` on ``x.device``.
    """
    device = x.device
    per_graph_ids = [
        graph_index * torch.ones(subgraph.number_of_nodes()).to(device)
        for graph_index, subgraph in enumerate(dgl.unbatch(g))
    ]
    return torch.cat(per_graph_ids, dim=0)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def obtain_clustering_for_matched_showers(
    batch_g, model_output, y_all, local_rank, use_gt_clusters=False, add_fakes=True
):
    """Cluster hits into predicted showers, match them against truth particles,
    and build one small DGL graph per matched (and optionally fake) shower.

    Args:
        batch_g: batched DGL graph of hits; node data includes "pos_hits_xyz",
            "h", "e_hits", "chi_squared_tracks", "particle_number".
        model_output: per-hit network output; columns 0:3 are clustering
            coordinates, column 3 the condensation beta (when clustering is used).
        y_all: truth-particle container with `batch_number`, `E`, `m`, `pid`,
            `coord`, and `copy()` / `mask()` methods.
        local_rank: rank forwarded to `match_showers`.
        use_gt_clusters: if True, cluster labels are taken from the ground-truth
            "particle_number" field instead of running DPC clustering.
        add_fakes: if True, unmatched predicted showers are also turned into
            graphs and appended after the matched ones.

    Returns:
        Tuple of (batched shower graphs [matched then fakes], true energies of
        matched showers, summed hit energies [matched then fakes], matched truth
        PIDs, matched daughter-corrected true energies, matched truth coords,
        number of fakes, fakes_idx).

    NOTE(review): `fakes_idx` is overwritten on every event iteration, so the
    returned value refers to the LAST event in the batch only — confirm callers
    expect that.
    """
    graphs_showers_matched = []
    graphs_showers_fakes = []
    true_energy_showers = []
    reco_energy_showers = []
    reco_energy_showers_fakes = []
    energy_true_daughters = []
    y_pids_matched = []
    y_coords_matched = []
    if not use_gt_clusters:
        # Stash the network's clustering coordinates / beta on the graph so
        # each unbatched sub-graph carries its own slice.
        batch_g.ndata["coords"] = model_output[:, 0:3]
        batch_g.ndata["beta"] = model_output[:, 3]
    graphs = dgl.unbatch(batch_g)
    batch_id = y_all.batch_number
    for i in range(0, len(graphs)):
        # Truth particles belonging to event i.
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y = y_all.copy()
        y.mask(mask.flatten())
        dic["part_true"] = y
        if not use_gt_clusters:
            betas = torch.sigmoid(dic["graph"].ndata["beta"])
            X = dic["graph"].ndata["coords"]
        if use_gt_clusters:
            labels = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels = DPC_custom_CLD(X, dic["graph"], model_output.device)
        labels, _ = remove_bad_tracks_from_cluster(dic["graph"], labels)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        shower_p_unique = torch.unique(labels)
        # Hungarian-style matching between predicted clusters and truth showers.
        shower_p_unique, row_ind, col_ind, i_m_w, _ = match_showers(
            labels, dic, particle_ids, model_output, local_rank, i, None
        )
        row_ind = torch.Tensor(row_ind).to(model_output.device).long()
        col_ind = torch.Tensor(col_ind).to(model_output.device).long()
        if torch.sum(particle_ids == 0) > 0:
            # Particle id 0 is noise: shift so row indices address real particles.
            row_ind_ = row_ind - 1
        else:
            # if there is no zero then index 0 corresponds to particle 1.
            row_ind_ = row_ind
        # +1 because cluster label 0 is the noise/background cluster.
        index_matches = col_ind + 1
        index_matches = index_matches.to(model_output.device).long()

        for j, unique_showers_label in enumerate(index_matches):
            # Only keep labels matched exactly once (skip duplicate matches).
            if torch.sum(unique_showers_label == index_matches) == 1:
                index_in_matched = torch.argmax(
                    (unique_showers_label == index_matches) * 1
                )
                mask = labels == unique_showers_label
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                # Build an edgeless per-shower graph carrying the hit features.
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)
                g.ndata["h"] = graphs[i].ndata["h"][mask]
                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                energy_t = dic["part_true"].E.to(model_output.device)
                # NOTE(review): `.m` presumably holds the daughter-corrected
                # energy despite the name suggesting mass — confirm upstream.
                energy_t_corr_daughters = dic["part_true"].m.to(
                    model_output.device
                )
                true_energy_shower = energy_t[row_ind_[j]]
                y_pids_matched.append(y.pid[row_ind_[j]].item())
                y_coords_matched.append(y.coord[row_ind_[j]].detach().cpu().numpy())
                energy_true_daughters.append(energy_t_corr_daughters[row_ind_[j]])
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                graphs_showers_matched.append(g)
                true_energy_showers.append(true_energy_shower.view(-1))
                reco_energy_showers.append(reco_energy_shower.view(-1))

        # Mark matched showers (and the noise cluster at position 0) with -1;
        # whatever remains is a fake. NOTE: this mutates shower_p_unique in place.
        pred_showers = shower_p_unique
        pred_showers[index_matches] = -1
        pred_showers[0] = -1
        mask_fakes = pred_showers != -1
        fakes_idx = torch.where(mask_fakes)[0]
        if add_fakes:
            for j in fakes_idx:
                # NOTE(review): `j` is a position within shower_p_unique, and it
                # is compared against cluster label VALUES — this is only valid
                # if labels are consecutive integers 0..K; confirm.
                mask = labels == j
                sls_graph = graphs[i].ndata["pos_hits_xyz"][mask][:, 0:3]
                g = dgl.graph(([], []))
                g.add_nodes(sls_graph.shape[0])
                g = g.to(sls_graph.device)

                g.ndata["h"] = graphs[i].ndata["h"][mask]

                if "pos_pxpypz" in graphs[i].ndata:
                    g.ndata["pos_pxpypz"] = graphs[i].ndata["pos_pxpypz"][mask]
                if "pos_pxpypz_at_vertex" in graphs[i].ndata:
                    g.ndata["pos_pxpypz_at_vertex"] = graphs[i].ndata[
                        "pos_pxpypz_at_vertex"
                    ][mask]
                g.ndata["chi_squared_tracks"] = graphs[i].ndata["chi_squared_tracks"][mask]
                graphs_showers_fakes.append(g)
                reco_energy_shower = torch.sum(graphs[i].ndata["e_hits"][mask])
                reco_energy_showers_fakes.append(reco_energy_shower.view(-1))

    # Fakes are appended AFTER all matched showers, so downstream indexing can
    # rely on the first len(true_energy_showers) entries being matched.
    graphs_showers_matched = dgl.batch(graphs_showers_matched + graphs_showers_fakes)
    true_energy_showers = torch.cat(true_energy_showers, dim=0)
    reco_energy_showers = torch.cat(reco_energy_showers + reco_energy_showers_fakes, dim=0)
    e_true_corr_daughters = torch.cat(energy_true_daughters, dim=0)
    number_of_fakes = len(reco_energy_showers_fakes)
    return (
        graphs_showers_matched,
        true_energy_showers,
        reco_energy_showers,
        y_pids_matched,
        e_true_corr_daughters,
        y_coords_matched,
        number_of_fakes,
        fakes_idx
    )
|
src/models/E_correction_module.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import io
|
| 4 |
+
import pickle
|
| 5 |
+
|
| 6 |
+
class Net(nn.Module):
    """Small fully-connected regression head (in_features -> 64 -> 64 -> 64 -> out).

    Used for energy correction. When ``out_features > 1`` and ``return_raw``
    is False, the first output column is returned separately from the rest
    (e.g. energy vs. auxiliary predictions).
    """

    def __init__(self, in_features=13, out_features=1, return_raw=True):
        super(Net, self).__init__()
        self.out_features = out_features
        self.return_raw = return_raw
        self.model = nn.ModuleList(
            [
                # nn.BatchNorm1d(13),
                nn.Linear(in_features, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                # nn.BatchNorm1d(64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, out_features),
            ]
        )
        # When True, forward() returns a NumPy array (for SHAP-style explainers
        # that feed/expect numpy).
        self.explainer_mode = False

    def forward(self, x):
        """Run the MLP.

        Args:
            x: tensor (or array-like, converted to a tensor) of shape
               (batch, in_features).

        Returns:
            Tensor of shape (batch, out_features); a (energy, aux) tuple when
            ``out_features > 1`` and ``return_raw`` is False; a NumPy array in
            explainer mode.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        for layer in self.model:
            x = layer(x)
        if self.out_features > 1 and not self.return_raw:
            return x[:, 0], x[:, 1:]
        if self.explainer_mode:
            # Fix: .numpy() raises RuntimeError on tensors that require grad
            # (which the output always does when the net has trainable
            # parameters) and on CUDA tensors; detach and move to CPU first.
            return x.detach().cpu().numpy()
        return x

    def freeze_batchnorm(self):
        """Put the first BatchNorm1d layer (if any) into eval mode so its
        running statistics stop updating."""
        for layer in self.model:
            if isinstance(layer, nn.BatchNorm1d):
                layer.eval()
                print("Frozen batchnorm in 1st layer only - ", layer)
                break
|
| 43 |
+
|
src/models/Gatr_pf_e_noise.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file includes code adapted from:
|
| 3 |
+
|
| 4 |
+
Geometric Algebra Transformer (GATr)
|
| 5 |
+
https://github.com/Qualcomm-AI-research/geometric-algebra-transformer
|
| 6 |
+
|
| 7 |
+
The original implementation is by Qualcomm AI Research. It has been modified
|
| 8 |
+
and integrated into this project for particle-flow reconstruction at the
|
| 9 |
+
CLD detector (FCC-ee). Please refer to the original repository for
|
| 10 |
+
authorship, documentation, and license information.
|
| 11 |
+
"""
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import dgl
|
| 15 |
+
from src.layers.object_cond import object_condensation_loss2
|
| 16 |
+
from src.models.energy_correction_NN import EnergyCorrection
|
| 17 |
+
from src.layers.inference_oc import create_and_store_graph_output
|
| 18 |
+
import lightning as L
|
| 19 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 20 |
+
from xformers.ops.fmha import BlockDiagonalMask
|
| 21 |
+
import os
|
| 22 |
+
import wandb
|
| 23 |
+
from gatr import GATr, SelfAttentionConfig, MLPConfig
|
| 24 |
+
from gatr.interface import embed_point, extract_scalar, extract_point, embed_scalar
|
| 25 |
+
from src.utils.logger_wandb import log_losses_wandb
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ExampleWrapper(L.LightningModule):
    """GATr-based particle-flow model for the CLD detector.

    Embeds hit positions as projective-geometric-algebra points, runs the
    GATr backbone, and produces per-hit clustering coordinates plus a
    condensation beta. When ``args.correction`` is set, a downstream
    energy/PID correction stage (EnergyCorrection) is attached and trained.
    """

    def __init__(
        self,
        args,
        dev,
        blocks=10,
        hidden_mv_channels=16,
        hidden_s_channels=64,
        config=None
    ):
        super().__init__()
        # Allow loading checkpoints that miss optional submodules.
        self.strict_loading = False
        self.input_dim = 3   # xyz hit position
        self.output_dim = 4  # 3 clustering coords + 1 beta
        self.loss_final = 0  # running sum of train losses within an epoch
        self.number_b = 0    # number of batches seen within an epoch
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []
        self.args = args
        self.dev = dev
        self.config = config
        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=2,   # scalar inputs: hit energy and momentum
            out_s_channels=1,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )
        self.ScaledGooeyBatchNorm2_1 = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        # Clustering head: 3 extracted point coords -> 3 clustering coords.
        self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
        # Beta head: 2 scalar outputs -> condensation strength.
        self.beta = nn.Linear(2, 1)
        if self.args.correction:
            self.energy_correction = EnergyCorrection(self)
            self.ec_model_wrapper_charged = self.energy_correction.model_charged
            self.ec_model_wrapper_neutral = self.energy_correction.model_neutral
            self.pids_neutral = self.energy_correction.pids_neutral
            self.pids_charged = self.energy_correction.pids_charged
        else:
            self.pids_neutral = []
            self.pids_charged = []

    def forward(self, g, y, step_count, eval="", return_train=False, use_gt_clusters=False):
        """Run clustering (or take ground-truth clusters) and, when enabled,
        the energy-correction stage.

        Returns either the correction-stage result tuple, or
        (x, pred_energy_corr, 0, 0) where x is (N, 4) coords+beta.
        """
        if not use_gt_clusters:
            inputs = g.ndata["pos_hits_xyz"].float()
            inputs_scalar = g.ndata["hit_type"].float().view(-1, 1)
            inputs = self.ScaledGooeyBatchNorm2_1(inputs)
            # Embed hits as PGA points plus a scalar channel for the hit type.
            embedded_inputs = embed_point(inputs) + embed_scalar(inputs_scalar)
            embedded_inputs = embedded_inputs.unsqueeze(-2)  # (N, 1, 16)
            # Block-diagonal mask keeps attention within each event.
            mask = self.build_attention_mask(g)
            scalars = torch.cat((g.ndata["e_hits"].float(), g.ndata["p_hits"].float()), dim=1)
            embedded_outputs, scalar_outputs = self.gatr(
                embedded_inputs, scalars=scalars, attention_mask=mask
            )
            points = extract_point(embedded_outputs[:, 0, :])
            nodewise_outputs = extract_scalar(embedded_outputs)  # (N, 1, 1)
            x_point = points
            x_scalar = torch.cat(
                (nodewise_outputs.view(-1, 1), scalar_outputs.view(-1, 1)), dim=1
            )
            x_cluster_coord = self.clustering(x_point)
            beta = self.beta(x_scalar)
            g.ndata["final_cluster"] = x_cluster_coord
            g.ndata["beta"] = beta.view(-1)
            x = torch.cat((x_cluster_coord, beta.view(-1, 1)), dim=1)
        else:
            # Placeholder output when ground-truth clusters are used downstream.
            x = torch.ones_like(g.ndata["h"][:, 0:4])

        if self.args.correction:
            result = self.energy_correction.forward_correction(g, x, y, return_train)
            return result
        else:
            # NOTE(review): `beta` is undefined here when use_gt_clusters=True
            # and correction is disabled -> NameError. Presumably that
            # combination never occurs; confirm.
            pred_energy_corr = torch.ones_like(beta.view(-1, 1))
            return x, pred_energy_corr, 0, 0

    def build_attention_mask(self, g):
        """Block-diagonal attention mask so hits only attend within their event."""
        batch_numbers = obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )

    def unfreeze_all(self):
        """Re-enable gradients for the correction-stage submodules."""
        for p in self.energy_correction.model_charged.parameters():
            p.requires_grad = True
        for p in self.energy_correction.model_neutral.gatr_pid.parameters():
            p.requires_grad = True
        for p in self.energy_correction.model_neutral.PID_head.parameters():
            p.requires_grad = True

    def training_step(self, batch, batch_idx):
        """One optimization step: object-condensation loss, plus energy/PID
        correction losses when ``args.correction`` is set."""
        y = batch[1]
        batch_g = batch[0]
        # Only rank 0 passes the real step index (used for logging cadence).
        if self.trainer.is_global_zero:
            result = self(batch_g, y, batch_idx)
        else:
            result = self(batch_g, y, 1)

        model_output = result[0]
        e_cor = result[1]
        (loss, losses) = object_condensation_loss2(
            batch_g,
            model_output,
            e_cor,
            y,
            q_min=self.args.qmin,
            use_average_cc_pos=self.args.use_average_cc_pos,
        )
        if self.args.correction:
            self.energy_correction.global_step = self.global_step
            # After the first epoch the clustering is treated as fixed.
            fixed = self.current_epoch > 0
            loss_EC, loss_pos, loss_neutral_pid, loss_charged_pid = self.energy_correction.get_loss(
                batch_g, y, result, self.stats, fixed
            )
            # NOTE(review): the OC loss (and loss_pos) are replaced, not added
            # -- only correction losses are optimized here; confirm intended.
            loss = loss_EC + loss_neutral_pid + loss_charged_pid

        if self.trainer.is_global_zero:
            log_losses_wandb(True, batch_idx, 0, losses, loss)
        self.loss_final = loss.item() + self.loss_final
        self.number_b = self.number_b + 1
        del model_output
        del e_cor
        del losses
        return loss

    def validation_step(self, batch, batch_idx):
        """Run inference on one batch and, in predict mode, accumulate
        per-shower dataframes for later storage."""
        self.create_paths()
        y = batch[1]
        batch_g = batch[0]
        shap_vals, ec_x = None, None
        if self.args.correction:
            result = self(batch_g, y, 1, use_gt_clusters=self.args.use_gt_clusters)
            model_output = result[0]
            outputs = self.energy_correction.get_validation_step_outputs(batch_g, y, result)
            e_cor1, pred_pos, pred_ref_pt, pred_pid, num_fakes, extra_features, fakes_labels = outputs
            e_cor = e_cor1
        else:
            model_output, e_cor1, loss_ll, _ = self(batch_g, y, 1)
            # Without correction the per-hit energy factor is identically 1.
            e_cor1 = torch.ones_like(model_output[:, 0].view(-1, 1))
            e_cor = e_cor1
            pred_pos = None
            pred_pid = None
            pred_ref_pt = None
            num_fakes = None
            extra_features = None
            fakes_labels = None

        if self.args.predict:
            if self.args.correction:
                model_output1 = model_output
                e_corr = e_cor
            else:
                # Append the (unit) energy correction as an extra column.
                model_output1 = torch.cat((model_output, e_cor.view(-1, 1)), dim=1)
                e_corr = None

            (
                df_batch_pandora,
                df_batch1,
                self.total_number_events,
            ) = create_and_store_graph_output(
                batch_g,
                model_output1,
                y,
                0,
                batch_idx,
                0,
                path_save=self.show_df_eval_path,
                store=True,
                predict=True,
                e_corr=e_corr,
                ec_x=ec_x,
                total_number_events=self.total_number_events,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                use_gt_clusters=self.args.use_gt_clusters,
                number_of_fakes=num_fakes,
                extra_features=extra_features,
                fakes_labels=fakes_labels,
                pandora_available=self.args.pandora,
            )
            self.df_showers_pandora.append(df_batch_pandora)
            self.df_showers_db.append(df_batch1)
        del model_output

    def create_paths(self):
        """Cache the output directory for evaluation dataframes."""
        show_df_eval_path = os.path.join(self.args.model_prefix, "showers_df_evaluation")
        self.show_df_eval_path = show_df_eval_path

    def on_train_epoch_end(self):
        # Average training loss over the batches of this epoch.
        self.log("train_loss_epoch", self.loss_final / self.number_b)

    def on_train_epoch_start(self):
        self.loss_final = 0
        self.number_b = 0
        self.make_mom_zero()
        if self.current_epoch == 0:
            # Per-class running statistics used by the correction losses.
            self.stats = {}
            self.stats["counts"] = {}
            self.stats["counts_pid_neutral"] = {}
            self.stats["counts_pid_charged"] = {}

    def on_validation_epoch_start(self):
        self.total_number_events = 0
        self.make_mom_zero()
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []
        self.validation_step_outputs = []

    def make_mom_zero(self):
        """Freeze the input batch-norm's running statistics after warm-up
        epochs (or always in predict mode)."""
        if self.current_epoch > 1 or self.args.predict:
            print("making momentum 0")
            self.ScaledGooeyBatchNorm2_1.momentum = 0

    def on_validation_epoch_end(self):
        # Rank 0 concatenates and stores the accumulated shower dataframes.
        if self.trainer.is_global_zero:
            if self.args.predict:
                from src.layers.inference_oc import store_at_batch_end
                import pandas as pd

                if self.args.pandora:
                    self.df_showers_pandora = pd.concat(self.df_showers_pandora)
                else:
                    self.df_showers_pandora = []
                self.df_showers_db = pd.concat(self.df_showers_db)
                store_at_batch_end(
                    path_save=os.path.join(
                        self.args.model_prefix, "showers_df_evaluation"
                    ) + "/" + self.args.name_output,
                    df_batch_pandora=self.df_showers_pandora,
                    df_batch1=self.df_showers_db,
                    step=0,
                    predict=True,
                    store=True,
                    pandora_available=self.args.pandora
                )

        self.validation_step_outputs = []
        self.df_showers = []
        self.df_showers_pandora = []
        self.df_showers_db = []

    def configure_optimizers(self):
        """Adam + cosine annealing that settles to a fixed LR.

        NOTE(review): T_max = 36400 * 2 is hard-coded (presumably steps for
        ~2 epochs of a specific dataset) — confirm for other datasets.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.start_lr)
        scheduler = CosineAnnealingThenFixedScheduler(optimizer, T_max=int(36400 * 2), fixed_lr=1e-5)
        self.scheduler = scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "monitor": "train_loss_epoch",
                "frequency": 1,
            },
        }

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        # Custom scheduler does not follow the torch LRScheduler protocol,
        # so Lightning is told how to step it.
        scheduler.step()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def obtain_batch_numbers(g):
    """Return a 1-D float tensor (CPU) assigning each node of the batched
    DGL graph *g* the index of its sub-graph."""
    ids = [
        torch.full((sub.number_of_nodes(),), float(idx))
        for idx, sub in enumerate(dgl.unbatch(g))
    ]
    return torch.cat(ids, dim=0)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class CosineAnnealingThenFixedScheduler:
    """Cosine-anneal the learning rate for ``T_max`` steps, then hold it
    constant at ``fixed_lr``.

    Bug fix: the ``fixed_lr`` constructor argument was previously ignored —
    the post-annealing LR was hard-coded to 1e-6, which contradicted the
    ``eta_min`` passed to the cosine phase (callers pass 1e-5).
    """

    def __init__(self, optimizer, T_max, fixed_lr):
        # Cosine phase decays from the optimizer's initial LR down to fixed_lr,
        # so the transition to the constant phase is continuous.
        self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=fixed_lr)
        self.fixed_lr = fixed_lr  # was hard-coded to 1e-6, ignoring the argument
        self.T_max = T_max
        self.step_count = 0
        self.optimizer = optimizer

    def step(self):
        """Advance one step: cosine phase first, constant phase afterwards."""
        if self.step_count < self.T_max:
            self.cosine_scheduler.step()
        else:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.fixed_lr
        self.step_count += 1

    def get_last_lr(self):
        """Return the most recent learning rate for each param group."""
        if self.step_count < self.T_max:
            return self.cosine_scheduler.get_last_lr()
        else:
            return [self.fixed_lr for _ in self.optimizer.param_groups]

    def state_dict(self):
        """Serialize scheduler state for checkpointing."""
        return {
            "step_count": self.step_count,
            "cosine_scheduler_state": self.cosine_scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`."""
        self.step_count = state_dict["step_count"]
        self.cosine_scheduler.load_state_dict(state_dict["cosine_scheduler_state"])
|
src/models/energy_correction_NN.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PID + energy correction module.
|
| 3 |
+
The model is called after object condensation clustering to correct
|
| 4 |
+
reconstructed energies and predict particle IDs.
|
| 5 |
+
"""
|
| 6 |
+
import numpy as np
|
| 7 |
+
import wandb
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import CrossEntropyLoss
|
| 10 |
+
from torch_scatter import scatter_add, scatter_mean
|
| 11 |
+
from typing import NamedTuple, Any
|
| 12 |
+
|
| 13 |
+
from src.layers.utils_training import obtain_clustering_for_matched_showers
|
| 14 |
+
from src.utils.post_clustering_features import (
|
| 15 |
+
get_post_clustering_features, get_extra_features, calculate_eta, calculate_phi,
|
| 16 |
+
)
|
| 17 |
+
from src.utils.pid_conversion import pid_conversion_dict
|
| 18 |
+
from src.layers.regression.loss_regression import obtain_PID_charged, obtain_PID_neutral
|
| 19 |
+
from src.models.energy_correction_charged import ChargedEnergyCorrection
|
| 20 |
+
from src.models.energy_correction_neutral import (
|
| 21 |
+
NeutralEnergyCorrection, criterion_E_cor, correct_mask_neutral,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _ClusteringOutput(NamedTuple):
    """Structured return type for clustering_and_global_features.

    Showers are ordered matched-first, then fakes; per-shower tensors below
    follow that ordering.
    """
    graphs: Any  # batched DGL graph (feature-augmented)
    batch_idx: torch.Tensor  # per-node index of the shower it belongs to
    high_level_feats: torch.Tensor  # per-shower aggregate features
    charged_idx: torch.Tensor  # shower indices with >= 1 associated track
    neutral_idx: torch.Tensor  # shower indices with no associated track
    feats_charged: torch.Tensor  # NaN-zeroed high_level_feats[charged_idx]
    feats_neutral: torch.Tensor  # NaN-zeroed high_level_feats[neutral_idx]
    pred_energy: torch.Tensor  # ones placeholder, filled by forward_correction
    pred_pos: torch.Tensor  # ones placeholder, filled by forward_correction
    pred_pid: torch.Tensor  # ones placeholder, filled by forward_correction
    true: Any  # matched true shower energies
    true_pid: torch.Tensor  # matched true PIDs
    true_coords: torch.Tensor  # matched true coordinates
    sum_e: torch.Tensor  # summed hit energy per shower
    e_true_daughters: torch.Tensor  # daughter-corrected true energies
    n_fakes: int  # number of unmatched (fake) showers appended at the end
    extra_features: torch.Tensor  # beta-derived per-shower extras
    fakes_idx: torch.Tensor  # fake-shower indices (last event of the batch)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _zero_nans(t: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
out = t.clone()
|
| 49 |
+
out[out != out] = 0
|
| 50 |
+
return out
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _decode_pid(pred_pid: torch.Tensor, pids: list, logits: torch.Tensor, idx: torch.Tensor) -> None:
|
| 54 |
+
if pids and len(idx):
|
| 55 |
+
labels = np.array(pids)[np.argmax(logits.cpu().detach(), axis=1)]
|
| 56 |
+
pred_pid[idx.flatten()] = torch.tensor(labels).long().to(idx.device)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class EnergyCorrection:
|
| 60 |
+
    def __init__(self, main_model):
        """Attach the energy/PID correction stage to *main_model*.

        Args:
            main_model: the LightningModule owning this stage; provides
                `args`, `device`, and the trainer (for `global_rank`).
        """
        self.args = main_model.args
        self.get_PID_categories()
        self.get_energy_correction()
        self.pid_conversion_dict = pid_conversion_dict
        # Kept as a back-reference; creates a reference cycle with main_model
        # (harmless for module lifetime, but note for serialization).
        self.main_model = main_model
        # Mirrors the trainer's global step; updated by training_step.
        self.global_step = 0
|
| 67 |
+
|
| 68 |
+
    def get_PID_categories(self):
        """Define which internal PID class indices are neutral vs charged.

        NOTE(review): the index-to-particle mapping lives in
        pid_conversion_dict — confirm 2/3 = neutral, 0/1/4 = charged there.
        """
        self.pids_neutral = [2, 3]
        self.pids_charged = [0, 1, 4]
|
| 71 |
+
|
| 72 |
+
    def get_energy_correction(self):
        """Instantiate the per-category correction networks (charged showers
        use track information; neutrals are calorimeter-only)."""
        self.model_charged = ChargedEnergyCorrection(args=self.args)
        self.model_neutral = NeutralEnergyCorrection(args=self.args)
|
| 75 |
+
|
| 76 |
+
    def clustering_and_global_features(self, g, x, y, add_fakes=True) -> _ClusteringOutput:
        """Cluster hits into showers, match to truth, and build per-shower
        global features for the correction networks.

        Args:
            g: batched DGL graph of hits.
            x: per-hit network output (clustering coords + beta).
            y: truth-particle container.
            add_fakes: forward unmatched showers as well (used in predict mode).

        Returns:
            _ClusteringOutput with per-shower graphs, aggregate features,
            charged/neutral index splits, and placeholder prediction tensors.
        """
        (
            graphs_new, true_new, sum_e, true_pid,
            e_true_corr_daughters, true_coords, number_of_fakes, fakes_idx,
        ) = obtain_clustering_for_matched_showers(
            g, x, y, self.main_model.trainer.global_rank,
            use_gt_clusters=self.args.use_gt_clusters,
            add_fakes=add_fakes,
        )

        # Per-node index of the shower (sub-graph) each node belongs to.
        batch_num_nodes = graphs_new.batch_num_nodes()
        batch_idx = []
        for i, n in enumerate(batch_num_nodes):
            batch_idx.extend([i] * n)
        batch_idx = torch.tensor(batch_idx).to(self.main_model.device)

        # Normalize hit positions; 3300 is presumably a detector length scale
        # (mm) — TODO confirm.
        graphs_new.ndata["h"][:, 0:3] = graphs_new.ndata["h"][:, 0:3] / 3300
        # Sum node features per shower, then broadcast back to every node so
        # each node also sees its shower's aggregate.
        graphs_sum_features = scatter_add(graphs_new.ndata["h"], batch_idx, dim=0)
        graphs_sum_features = graphs_sum_features[batch_idx]
        # Last column of "h" is the raw beta logit — TODO confirm.
        betas = torch.sigmoid(graphs_new.ndata["h"][:, -1])
        graphs_new.ndata["h"] = torch.cat(
            (graphs_new.ndata["h"], graphs_sum_features), dim=1
        )

        high_level = get_post_clustering_features(graphs_new, sum_e)
        extra_features = get_extra_features(graphs_new, betas)

        dev = graphs_new.ndata["h"].device
        n = high_level.shape[0]
        # Placeholders; forward_correction overwrites them per category.
        pred_energy = torch.ones(n, device=dev)
        pred_pos = torch.ones(n, 3, device=dev)
        pred_pid = torch.ones(n, device=dev).long()

        # Shower direction features from the mean hit position (columns 0:3
        # of "h" are the normalized xyz).
        node_features_avg = scatter_mean(graphs_new.ndata["h"], batch_idx, dim=0)[:, 0:3]
        eta = calculate_eta(node_features_avg[:, 0], node_features_avg[:, 1], node_features_avg[:, 2])
        phi = calculate_phi(node_features_avg[:, 0], node_features_avg[:, 1])
        high_level = torch.cat(
            (high_level, node_features_avg, eta.view(-1, 1), phi.view(-1, 1)), dim=1
        )

        # Column 7 of the post-clustering features is the track count per
        # shower — assumed from get_post_clustering_features; confirm.
        num_tracks = high_level[:, 7]
        charged_idx = torch.where(num_tracks >= 1)[0]
        neutral_idx = torch.where(num_tracks < 1)[0]
        # Every shower is exactly one of charged/neutral, and features are
        # one row per shower.
        assert len(charged_idx) + len(neutral_idx) == len(num_tracks)
        assert high_level.shape[0] == graphs_new.batch_num_nodes().shape[0]

        return _ClusteringOutput(
            graphs=graphs_new,
            batch_idx=batch_idx,
            high_level_feats=high_level,
            charged_idx=charged_idx,
            neutral_idx=neutral_idx,
            feats_charged=_zero_nans(high_level[charged_idx]),
            feats_neutral=_zero_nans(high_level[neutral_idx]),
            pred_energy=pred_energy,
            pred_pos=pred_pos,
            pred_pid=pred_pid,
            true=true_new,
            true_pid=true_pid,
            true_coords=true_coords,
            sum_e=sum_e,
            e_true_daughters=e_true_corr_daughters,
            n_fakes=number_of_fakes,
            extra_features=extra_features,
            fakes_idx=fakes_idx,
        )
|
| 142 |
+
|
| 143 |
+
def forward_correction(self, g, x, y, return_train):
    """Run the energy/PID correction stage on top of the clustering output.

    Clusters hits into candidate showers, then runs the charged and neutral
    correction sub-models on their respective candidate subsets and merges
    the per-candidate predictions (energy, position, PID, reference point)
    back into full-length tensors indexed like the clustered showers.

    :param g: batched input graph of hits
    :param x: raw model output passed through to the caller
    :param y: ground-truth particle information
    :param return_train: if True, return the compact training tuple;
        otherwise return the extended validation tuple
    :return: tuple of predictions and matched truth quantities (two layouts,
        selected by ``return_train``)
    """
    # Cluster hits and build per-shower global features; fakes are only
    # injected in prediction mode.
    cf = self.clustering_and_global_features(g, x, y, add_fakes=self.args.predict)

    charged_energies = self.model_charged.charged_prediction(
        cf.graphs, cf.charged_idx, cf.feats_charged
    )
    neutral_energies, neutral_pxyz_avg = self.model_neutral.neutral_prediction(
        cf.graphs, cf.neutral_idx, cf.feats_neutral
    )

    # The sub-models return 4-tuples when a PID head is configured and
    # 3-tuples otherwise.
    # NOTE(review): when the corresponding pids_* list is empty,
    # charged_PID_pred / charged_ref_pt_pred (and the neutral counterparts)
    # are never bound, yet they are used unconditionally below — this would
    # raise NameError. Presumably both PID lists are always non-empty in
    # practice; confirm before relying on the empty-list path.
    if len(self.pids_charged):
        charged_energies, charged_positions, charged_PID_pred, charged_ref_pt_pred = charged_energies
    else:
        charged_energies, charged_positions, _ = charged_energies
    if len(self.pids_neutral):
        neutral_energies, neutral_positions, neutral_PID_pred, neutral_ref_pt_pred = neutral_energies
    else:
        neutral_energies, neutral_positions, _ = neutral_energies

    # Scatter the per-subset predictions back into the full candidate list.
    cf.pred_energy[cf.charged_idx.flatten()] = charged_energies
    cf.pred_energy[cf.neutral_idx.flatten()] = neutral_energies

    _decode_pid(cf.pred_pid, self.pids_charged, charged_PID_pred, cf.charged_idx)
    _decode_pid(cf.pred_pid, self.pids_neutral, neutral_PID_pred, cf.neutral_idx)

    # Corrected energies must be physical (non-negative).
    cf.pred_energy[cf.pred_energy < 0] = 0.0

    pred_ref_pt = torch.ones_like(cf.pred_pos)
    if len(cf.charged_idx):
        pred_ref_pt[cf.charged_idx.flatten()] = charged_ref_pt_pred.to(pred_ref_pt.device)
        cf.pred_pos[cf.charged_idx.flatten()] = charged_positions.float().to(cf.pred_pos.device)
    if len(cf.neutral_idx):
        pred_ref_pt[cf.neutral_idx.flatten()] = neutral_ref_pt_pred.to(cf.neutral_idx.device)
        cf.pred_pos[cf.neutral_idx.flatten()] = neutral_positions.to(cf.neutral_idx.device).float()

    predictions = {
        "pred_energy_corr": cf.pred_energy,
        "pred_pos": cf.pred_pos,
        "neutrals_idx": cf.neutral_idx.flatten(),
        "charged_idx": cf.charged_idx.flatten(),
        "pred_ref_pt": pred_ref_pt,
        "extra_features": cf.extra_features,
        "fakes_labels": cf.fakes_idx,
    }
    # PID outputs are only attached when at least one PID head is configured.
    if len(self.pids_charged) or len(self.pids_neutral):
        predictions["pred_PID"] = cf.pred_pid
        predictions["charged_PID_pred"] = charged_PID_pred
        predictions["neutral_PID_pred"] = neutral_PID_pred

    if return_train:
        # NOTE(review): cf.true appears twice in this tuple (positions 3 and
        # 6); downstream unpacking in get_loss names them e_true and
        # e_true_corr_daughters respectively — confirm this duplication is
        # intentional.
        return x, predictions, cf.true, cf.sum_e, cf.true_pid, cf.true, cf.true_coords, cf.n_fakes
    else:
        return (
            x, predictions, cf.true, cf.sum_e, cf.graphs, cf.batch_idx,
            cf.high_level_feats, cf.true_pid, cf.e_true_daughters,
            cf.true_coords, cf.n_fakes,
        )
|
| 200 |
+
|
| 201 |
+
def get_loss(self, batch_g, y, result, stats, fixed):
    """Compute the correction-stage losses from a forward_correction result.

    Builds (a) an L1 energy-regression loss on genuine neutral candidates
    whose summed-hit energy is within 60% of truth, and (b) cross-entropy
    PID losses for the charged and neutral PID heads when configured.

    :param batch_g: input batched graph (unused here; kept for interface)
    :param y: ground truth (unused here; kept for interface)
    :param result: the training-mode tuple returned by forward_correction
    :param stats: running statistics object (unused here)
    :param fixed: flag forwarded to pid_loss (e.g. frozen-backbone mode)
    :return: (loss_EC_neutrals, 0, loss_neutral_pid, loss_charged_pid)
    """
    (
        model_output, dic_e_cor, e_true, e_sum_hits, new_graphs, batch_id,
        graph_level_features, pid_true_matched, e_true_corr_daughters,
        part_coords_matched, num_fakes,
    ) = result

    e_cor = dic_e_cor["pred_energy_corr"]
    # Keep only neutral candidates whose matched truth PID is a genuinely
    # neutral species (photon / K-long / neutron).
    mask_neutral_for_loss = correct_mask_neutral(
        torch.tensor(pid_true_matched), dic_e_cor["neutrals_idx"]
    )

    e_true_neutrals = e_true[mask_neutral_for_loss]
    e_pred_neutrals = e_cor[mask_neutral_for_loss]
    e_reco_neutrals = e_sum_hits[mask_neutral_for_loss]
    # Exclude badly-reconstructed showers (>60% relative deviation of the
    # summed-hit energy from truth) so outliers don't dominate the L1 loss.
    in_distribution = (torch.abs(e_true_neutrals - e_reco_neutrals) / e_true_neutrals) < 0.6
    ypred = e_pred_neutrals[in_distribution]
    ybatch = e_true_neutrals[in_distribution]

    loss_EC_neutrals = criterion_E_cor(ypred.flatten(), ybatch.flatten()) if len(ypred) > 0 else 0
    wandb.log({"loss_EC_neutrals": loss_EC_neutrals})

    # PID losses default to 0 when the corresponding head is disabled.
    loss_neutral_pid = 0
    loss_charged_pid = 0

    if len(self.pids_charged):
        charged_PID_pred, charged_PID_true_onehot, mask_charged = obtain_PID_charged(
            dic_e_cor, pid_true_matched, self.pids_charged, self.args, self.pid_conversion_dict
        )
        loss_charged_pid, acc_charged = pid_loss(
            charged_PID_pred, charged_PID_true_onehot,
            e_true[dic_e_cor["charged_idx"]], mask_charged, fixed, "charged",
        )
        wandb.log({"loss_charged_pid": loss_charged_pid})

    if len(self.pids_neutral):
        neutral_PID_pred, neutral_PID_true_onehot, mask_neutral = obtain_PID_neutral(
            dic_e_cor, pid_true_matched, self.pids_neutral, self.args, self.pid_conversion_dict
        )
        loss_neutral_pid, acc_neutral = pid_loss(
            neutral_PID_pred, neutral_PID_true_onehot,
            e_true, mask_neutral, fixed, "neutral",
        )
        wandb.log({"loss_neutral_pid": loss_neutral_pid})

    # Second element is a placeholder (no charged-energy loss here).
    return loss_EC_neutrals, 0, loss_neutral_pid, loss_charged_pid
|
| 247 |
+
|
| 248 |
+
def get_validation_step_outputs(self, batch_g, y, result):
    """Assemble per-candidate validation outputs from a correction result.

    Unpacks the prediction dictionary, rebuilds a dense PID-logit matrix
    over all candidates (charged and neutral heads write into disjoint,
    hard-coded columns), and concatenates it onto the extra features.

    :param batch_g: input batched graph (unused here; kept for interface)
    :param y: ground truth (unused here; kept for interface)
    :param result: the validation-mode tuple returned by forward_correction
    :return: (e_cor, pred_pos, pred_ref_pt, pred_pid, num_fakes,
        extra_features, fakes_labels)
    """
    (
        model_output, e_cor, e_true, e_sum_hits,
        new_graphs, batch_id, graph_level_features,
        pid_true_matched, e_true_corr_daughters,
        coords_true, num_fakes,
    ) = result

    # NOTE(review): charged_idx/neutral_idx are only bound when the
    # respective PID list is non-empty, yet both are used unconditionally
    # below — this path assumes both heads are enabled; confirm.
    if len(self.pids_charged):
        charged_idx = e_cor["charged_idx"]
    if len(self.pids_neutral):
        neutral_idx = e_cor["neutrals_idx"]
    pred_pid = e_cor["pred_PID"]
    charged_PID_pred = e_cor["charged_PID_pred"]
    neutral_PID_pred = e_cor["neutral_PID_pred"]
    pred_pos = e_cor["pred_pos"]
    pred_ref_pt = e_cor["pred_ref_pt"]
    extra_features = e_cor["extra_features"]
    fakes_labels = e_cor["fakes_labels"]
    # From here on e_cor is the corrected-energy tensor, not the dict.
    e_cor = e_cor["pred_energy_corr"]

    # Dense logit matrix over all candidates: charged head fills columns
    # 0, 1, 4 and neutral head fills columns 2, 3. NOTE(review): this
    # hard-coded layout assumes exactly 3 charged + 2 neutral classes
    # (pid_channels [0,1,4] / [2,3]); verify against the model config.
    PID_logits = torch.zeros(len(e_cor), len(self.pids_charged) + len(self.pids_neutral)).float()
    PID_logits[charged_idx.cpu(), 0] = charged_PID_pred.detach().cpu()[:, 0]
    PID_logits[charged_idx.cpu(), 1] = charged_PID_pred.detach().cpu()[:, 1]
    PID_logits[charged_idx.cpu(), 4] = charged_PID_pred.detach().cpu()[:, 2]
    PID_logits[neutral_idx.cpu(), 2] = neutral_PID_pred.detach().cpu()[:, 0]
    PID_logits[neutral_idx.cpu(), 3] = neutral_PID_pred.detach().cpu()[:, 1]

    extra_features = extra_features.detach().cpu()
    extra_features = torch.cat((extra_features, PID_logits), dim=1).numpy()

    return e_cor, pred_pos, pred_ref_pt, pred_pid, num_fakes, extra_features, fakes_labels
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def pid_loss(
    pid_pred_all: torch.Tensor,
    pid_true_all: torch.Tensor,
    e_true: torch.Tensor,
    mask: torch.Tensor,
    frozen: bool = False,
    name: str = "",
) -> tuple:
    """Cross-entropy PID loss and a raw agreement score on the masked subset.

    ``e_true``, ``frozen`` and ``name`` are accepted for interface
    compatibility but not used in the computation. Returns ``(0, 0)`` when
    either the full prediction tensor or the masked selection is empty, so
    callers can sum the result into a total loss unconditionally.
    """
    if len(pid_pred_all) == 0:
        return 0, 0

    keep = mask.bool()
    selected_pred = pid_pred_all[keep]
    selected_true = pid_true_all[keep]
    if len(selected_pred) == 0:
        return 0, 0

    # Element-wise agreement between predictions and targets, normalised by
    # the number of selected candidates (kept identical to the original
    # formulation).
    agreement = torch.sum(selected_pred == selected_true) / len(selected_pred)
    criterion = CrossEntropyLoss()
    return criterion(selected_pred, selected_true), agreement
|
src/models/energy_correction_charged.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
energy_correction_charged.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch_scatter import scatter_sum
|
| 8 |
+
from xformers.ops.fmha import BlockDiagonalMask
|
| 9 |
+
import dgl
|
| 10 |
+
|
| 11 |
+
from gatr import GATr, SelfAttentionConfig, MLPConfig
|
| 12 |
+
from gatr.interface import embed_point, embed_scalar
|
| 13 |
+
from src.layers.tools_for_regression import PickPAtDCA
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ChargedEnergyCorrection(nn.Module):
    """Energy/PID correction model for charged shower candidates.

    A small GATr transformer embeds the hits of each candidate shower; its
    pooled output, concatenated with graph-level features, feeds an MLP PID
    head. The energy and direction themselves are taken from the track
    momentum at the distance of closest approach (PickPAtDCA), not from the
    network.
    """

    def __init__(self, args):
        super().__init__()
        # Width of the graph-level feature vector fed alongside the GATr
        # embedding.
        self.in_features_global = 16
        self.in_features_gnn = 16  # GATr multivector output dim per batch
        # Output-logit slots this head owns in the combined PID layout
        # (see get_validation_step_outputs).
        self.pid_channels = [0, 1, 4]
        n_layers = 3
        self.args = args

        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=4,
            in_s_channels=2,
            out_s_channels=None,
            hidden_s_channels=4,
            num_blocks=3,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )

        out_features_gnn = self.in_features_gnn
        in_features_global = self.in_features_global
        n_pid_classes = len(self.pid_channels)

        # +1 for the recovered-energy feature appended in predict().
        pid_layers = [nn.Linear(out_features_gnn + in_features_global + 1, 64)]
        for _ in range(n_layers - 1):
            pid_layers.append(nn.Linear(64, 64))
            pid_layers.append(nn.ReLU())
        pid_layers.append(nn.Linear(64, n_pid_classes))
        self.PID_head = nn.Sequential(*pid_layers)

        self.PickPAtDCA = PickPAtDCA()

    def charged_prediction(self, graphs_new, charged_idx, graphs_high_level_features):
        """Run predict() on the sub-batch of charged candidate graphs.

        Returns the predict() 4-tuple, or four empty tensors when there are
        no charged candidates in the batch.
        """
        unbatched = dgl.unbatch(graphs_new)
        if len(charged_idx) > 0:
            charged_graphs = dgl.batch([unbatched[i] for i in charged_idx])
            charged_energies = self.predict(
                graphs_high_level_features,
                charged_graphs,
            )
        else:
            empty = torch.tensor([]).to(graphs_new.ndata["h"].device)
            charged_energies = [empty, empty, empty, empty]
        return charged_energies

    def predict(self, x_global_features, graphs_new=None):
        """
        Forward pass for charged energy correction.
        :param x_global_features: Global graph-level features (batch, in_features_global)
        :param graphs_new: Batched DGL graph of hit-level data
        :return: (E, direction, pid_pred, ref_pt_pred)
        """
        # Map each hit to the index of its graph within the batch.
        if graphs_new is not None:
            batch_num_nodes = graphs_new.batch_num_nodes()
            batch_idx = []
            for i, n in enumerate(batch_num_nodes):
                batch_idx.extend([i] * n)
            batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Node-feature layout (assumed from the slicing here — confirm
        # against the dataset builder): [:, 0:3] hit position, [:, 4:8]
        # one-hot hit type, [:, 8] hit energy, [:, 9] track momentum.
        hits_points = graphs_new.ndata["h"][:, 0:3]
        hit_type = graphs_new.ndata["h"][:, 4:8].argmax(dim=1)
        p = graphs_new.ndata["h"][:, 9]
        e = graphs_new.ndata["h"][:, 8]

        # Geometric-algebra embedding: point + scalar hit-type channel.
        embedded_inputs = embed_point(hits_points) + embed_scalar(hit_type.view(-1, 1))
        extra_scalars = torch.cat([p.unsqueeze(1), e.unsqueeze(1)], dim=1)
        # Block-diagonal mask keeps attention within each candidate graph.
        mask = self.build_attention_mask(graphs_new)
        embedded_inputs = embedded_inputs.unsqueeze(-2)

        embedded_outputs, _ = self.gatr(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        # Sum-pool hit embeddings into one vector per candidate.
        embedded_outputs_per_batch = scatter_sum(embedded_outputs[:, 0, :], batch_idx, dim=0)

        # Extra derived feature: ratio of global feature columns 6 and 3
        # (semantics defined by get_post_clustering_features — verify).
        recovered_E = x_global_features[:, 6] / x_global_features[:, 3]
        x_global_features = torch.cat((x_global_features, recovered_E.view(-1, 1)), dim=1)
        model_x = torch.cat([x_global_features, embedded_outputs_per_batch], dim=1)

        pid_pred = self.PID_head(model_x)
        # Energy and direction come from the track momentum at the DCA:
        # E = |p| (massless approximation), direction = p normalised.
        p_tracks, pos, ref_pt_pred = self.PickPAtDCA.predict(x_global_features, graphs_new)
        E = torch.norm(pos, dim=1)
        pos = (pos / torch.norm(pos, dim=1).unsqueeze(1)).clone()
        return E, pos, pid_pred, ref_pt_pred

    @staticmethod
    def obtain_batch_numbers(g):
        # Per-hit graph index, built by unbatching (one constant per graph).
        graphs_eval = dgl.unbatch(g)
        batch_numbers = []
        for index, gj in enumerate(graphs_eval):
            num_nodes = gj.number_of_nodes()
            batch_numbers.append(index * torch.ones(num_nodes))
        return torch.cat(batch_numbers, dim=0)

    def build_attention_mask(self, g):
        # One attention block per graph: hits attend only within their own
        # candidate shower.
        batch_numbers = self.obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
|
src/models/energy_correction_neutral.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
energy_correction_neutral.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch_scatter import scatter_sum
|
| 9 |
+
from xformers.ops.fmha import BlockDiagonalMask
|
| 10 |
+
import dgl
|
| 11 |
+
|
| 12 |
+
from gatr import GATr, SelfAttentionConfig, MLPConfig
|
| 13 |
+
from gatr.interface import embed_point, embed_scalar
|
| 14 |
+
from src.models.E_correction_module import Net
|
| 15 |
+
from src.layers.tools_for_regression import ECNetWrapperAvg, AverageHitsP
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class NeutralEnergyCorrection(nn.Module):
    """Energy/PID correction model for neutral shower candidates.

    Two independent GATr transformers embed the hits of each candidate:
    one feeds an MLP energy regressor, the other an MLP PID head. The
    direction/reference point come from an energy-weighted hit average
    (AverageHitsP), since neutrals have no track.
    """

    def __init__(self, args):
        super().__init__()
        # Width of the graph-level feature vector fed alongside the GATr
        # embedding.
        self.in_features_global = 16
        self.in_features_gnn = 16  # GATr multivector output dim per batch
        # Output-logit slots this head owns in the combined PID layout
        # (see get_validation_step_outputs).
        self.pid_channels = [2, 3]
        self.args = args
        n_layers = 3

        # Shared configuration for both GATr backbones (energy and PID).
        gatr_kwargs = dict(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=4,
            in_s_channels=2,
            out_s_channels=None,
            hidden_s_channels=4,
            num_blocks=3,
            attention=SelfAttentionConfig(),
            mlp=MLPConfig(),
        )
        self.gatr = GATr(**gatr_kwargs)
        self.gatr_pid = GATr(**gatr_kwargs)

        out_features_gnn = self.in_features_gnn
        in_features_global = self.in_features_global
        n_pid_classes = len(self.pid_channels)
        out_f = 1  # Energy prediction (scalar)

        pid_layers = [nn.Linear(out_features_gnn + in_features_global, 64)]
        for _ in range(n_layers - 1):
            pid_layers.append(nn.Linear(64, 64))
            pid_layers.append(nn.ReLU())
        pid_layers.append(nn.Linear(64, n_pid_classes))
        self.PID_head = nn.Sequential(*pid_layers)

        # MLP regressor mapping [global features | GATr embedding] -> energy.
        self.model = Net(
            in_features=out_features_gnn + in_features_global,
            out_features=out_f,
            return_raw=True,
        )
        self.ec_model_wrapper_neutral_avg = ECNetWrapperAvg()
        # ECAL-only hit averaging for the direction estimate.
        self.AvgHits = AverageHitsP(ecal_only=True)

    def neutral_prediction(self, graphs_new, neutral_idx, features_neutral_no_nan):
        """Run predict() on the sub-batch of neutral candidate graphs.

        Returns (predict() 4-tuple, averaged-momentum estimate), or empty
        tensors for both when there are no neutral candidates.
        """
        unbatched = dgl.unbatch(graphs_new)
        if len(neutral_idx) > 0:
            neutral_graphs = dgl.batch([unbatched[i] for i in neutral_idx])
            neutral_energies = self.predict(
                features_neutral_no_nan,
                neutral_graphs,
            )
            # Secondary momentum estimate from the hit-averaging wrapper
            # (element [1] of its prediction tuple).
            neutral_pxyz_avg = self.ec_model_wrapper_neutral_avg.predict(
                features_neutral_no_nan,
                neutral_graphs,
            )[1]
        else:
            empty = torch.tensor([]).to(graphs_new.ndata["h"].device)
            neutral_energies = [empty, empty, empty, empty]
            neutral_pxyz_avg = empty
        return neutral_energies, neutral_pxyz_avg

    def predict(self, x_global_features, graphs_new=None):
        """
        Forward pass for neutral energy correction.
        :param x_global_features: Global graph-level features (batch, in_features_global)
        :param graphs_new: Batched DGL graph of hit-level data
        :return: (E_pred, direction, pid_pred, ref_pt_pred)
        """
        # Map each hit to the index of its graph within the batch.
        if graphs_new is not None:
            batch_num_nodes = graphs_new.batch_num_nodes()
            batch_idx = []
            for i, n in enumerate(batch_num_nodes):
                batch_idx.extend([i] * n)
            batch_idx = torch.tensor(batch_idx).to(graphs_new.device)

        # Node-feature layout (assumed from the slicing here — confirm
        # against the dataset builder): [:, 0:3] hit position, [:, 4:8]
        # one-hot hit type, [:, 8] hit energy, [:, 9] track momentum.
        hits_points = graphs_new.ndata["h"][:, 0:3]
        hit_type = graphs_new.ndata["h"][:, 4:8].argmax(dim=1)
        p = graphs_new.ndata["h"][:, 9]
        e = graphs_new.ndata["h"][:, 8]

        embedded_inputs = embed_point(hits_points) + embed_scalar(hit_type.view(-1, 1))
        extra_scalars = torch.cat([p.unsqueeze(1), e.unsqueeze(1)], dim=1)
        # Block-diagonal mask keeps attention within each candidate graph.
        mask = self.build_attention_mask(graphs_new)
        embedded_inputs = embedded_inputs.unsqueeze(-2)

        # Energy branch: GATr embedding, sum-pooled per candidate.
        embedded_outputs, _ = self.gatr(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        embedded_outputs_per_batch = scatter_sum(embedded_outputs[:, 0, :], batch_idx, dim=0)
        model_x = torch.cat([x_global_features, embedded_outputs_per_batch], dim=1)

        # PID branch: a separate GATr backbone over the same inputs.
        embedded_outputs_pid, _ = self.gatr_pid(
            embedded_inputs, scalars=extra_scalars, attention_mask=mask
        )
        embedded_outputs_per_batch_pid = scatter_sum(
            embedded_outputs_pid[:, 0, :], batch_idx, dim=0
        )
        model_x_pid = torch.cat([x_global_features, embedded_outputs_per_batch_pid], dim=1)

        res = self.model(model_x)
        pid_pred = self.PID_head(model_x_pid)
        E_pred = res[:, 0]

        # Direction from the averaged (ECAL) hit positions, normalised to a
        # unit vector.
        _, p_pred, ref_pt_pred = self.AvgHits.predict(x_global_features, graphs_new)
        p_pred = (p_pred / torch.norm(p_pred, dim=1).unsqueeze(1)).clone()
        return E_pred, p_pred, pid_pred, ref_pt_pred

    @staticmethod
    def obtain_batch_numbers(g):
        # Per-hit graph index, built by unbatching (one constant per graph).
        graphs_eval = dgl.unbatch(g)
        batch_numbers = []
        for index, gj in enumerate(graphs_eval):
            num_nodes = gj.number_of_nodes()
            batch_numbers.append(index * torch.ones(num_nodes))
        return torch.cat(batch_numbers, dim=0)

    def build_attention_mask(self, g):
        # One attention block per graph: hits attend only within their own
        # candidate shower.
        batch_numbers = self.obtain_batch_numbers(g)
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def correct_mask_neutral(pid_neutral, neural_mask):
    """Restrict neutral-candidate indices to genuinely neutral species.

    Keeps only those entries of ``neural_mask`` whose matched truth PID
    (by absolute value) is a photon (22), K-long (130) or neutron (2112).
    """
    abs_pids = torch.abs(pid_neutral.to(neural_mask.device))
    neutral_species = torch.tensor([22, 130, 2112], device=abs_pids.device)
    is_neutral = torch.isin(abs_pids[neural_mask], neutral_species)
    return neural_mask[is_neutral.to(neural_mask.device)]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def criterion_E_cor(ypred, ytrue):
    """Mean absolute error between predicted and true energies.

    Returns the plain integer 0 when ``ypred`` is empty so callers can add
    the result into a total loss without special-casing.
    """
    if len(ypred) == 0:
        return 0
    per_element = F.l1_loss(ypred, ytrue, reduction="none")
    return torch.mean(per_element)
|
src/models/wrapper/example_mode_gatr_noise.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from src.models.Gatr_pf_e_noise import ExampleWrapper
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class GraphTransformerNetWrapper(torch.nn.Module):
    """Thin adapter exposing ExampleWrapper behind the framework's uniform
    model interface (constructed via get_model below)."""

    def __init__(self, args, dev, **kwargs) -> None:
        super().__init__()
        # The wrapped model; kept as a public attribute so callers can reach
        # the underlying ExampleWrapper directly.
        self.mod = ExampleWrapper(args, dev, **kwargs)

    def forward(self, g, y, step_count, **kwargs):
        # Pure pass-through to the wrapped model.
        return self.mod(g, y, step_count, **kwargs)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_model(data_config, args, dev, **kwargs):
    """Factory entry point used by the training framework.

    :param data_config: dataset configuration (not used by this wrapper)
    :param args: parsed command-line arguments forwarded to the model
    :param dev: target device forwarded to the model
    :return: (model, model_info) where model_info is an empty dict here
    """
    model = GraphTransformerNetWrapper(args, dev, **kwargs)
    model_info = {}
    return model, model_info
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_loss(data_config, **kwargs):
    """Return the default criterion for this wrapper (plain MSE).

    ``data_config`` and extra keyword arguments are accepted only for
    interface parity with other model wrappers; they are not used.
    """
    criterion = torch.nn.MSELoss()
    return criterion
|
src/train_lightning1.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import glob
|
| 6 |
+
import torch
|
| 7 |
+
import lightning as L
|
| 8 |
+
from lightning.pytorch.loggers import WandbLogger
|
| 9 |
+
|
| 10 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
|
| 11 |
+
|
| 12 |
+
from src.utils.parser_args import parser
|
| 13 |
+
from src.utils.train_utils import (
|
| 14 |
+
train_load,
|
| 15 |
+
test_load,
|
| 16 |
+
get_samples_steps_per_epoch,
|
| 17 |
+
model_setup,
|
| 18 |
+
set_gpus,
|
| 19 |
+
)
|
| 20 |
+
from src.utils.load_pretrained_models import (
|
| 21 |
+
load_train_model,
|
| 22 |
+
load_test_model,
|
| 23 |
+
)
|
| 24 |
+
from src.utils.callbacks import (
|
| 25 |
+
get_callbacks,
|
| 26 |
+
get_callbacks_eval,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ----------------------------------------------------------------------
|
| 31 |
+
# Helpers
|
| 32 |
+
# ----------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
def setup_wandb(args):
    """Build the Weights & Biases logger from command-line arguments.

    ``log_model="all"`` uploads every checkpoint artifact, which can be
    heavy for large models.
    """
    return WandbLogger(
        project=args.wandb_projectname,
        entity=args.wandb_entity,
        name=args.wandb_displayname,
        log_model="all",
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_trainer(args, gpus, logger, training=True):
    """Construct a Lightning Trainer for either training or evaluation.

    :param args: parsed command-line arguments
    :param gpus: devices spec passed to the Trainer
    :param logger: experiment logger (e.g. WandbLogger)
    :param training: selects training callbacks/limits vs evaluation ones
    """
    callbacks = get_callbacks(args) if training else get_callbacks_eval(args)

    # Correction mode lets Lightning pick the strategy; plain training uses
    # DDP. NOTE(review): the eval + non-correction path yields strategy=None,
    # which recent Lightning versions may reject (they expect "auto") —
    # confirm against the pinned Lightning release.
    strategy = "auto" if args.correction else "ddp" if training else None

    return L.Trainer(
        callbacks=callbacks,
        accelerator="gpu",
        devices=gpus,
        default_root_dir=args.model_prefix,
        logger=logger,
        max_epochs=args.num_epochs if training else None,
        strategy=strategy,
        limit_train_batches=args.train_batches if training else None,
        # Validation is capped at 5 batches during training to keep epochs
        # fast; evaluation runs use Lightning's default.
        limit_val_batches=5 if training else None,
    )
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ----------------------------------------------------------------------
|
| 62 |
+
# Main
|
| 63 |
+
# ----------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
def main():
    """Entry point: parse arguments, build data loaders and model, then
    train and/or evaluate depending on ``--predict``."""
    args = parser.parse_args()
    # NOTE(review): anomaly detection slows every backward pass noticeably;
    # consider gating this behind a debug flag.
    torch.autograd.set_detect_anomaly(True)

    training_mode = not args.predict
    args.local_rank = 0

    # --------------------------------------------------
    # Data
    # --------------------------------------------------
    args = get_samples_steps_per_epoch(args)

    if training_mode:
        # data_train[0] is treated as a directory/prefix; expand to the
        # actual parquet shards.
        args.data_train = glob.glob(args.data_train[0] + "*.parquet")
        train_loader, val_loader, data_config, train_input_names = train_load(args)
    else:
        test_loaders, data_config = test_load(args)

    # --------------------------------------------------
    # Model & devices
    # --------------------------------------------------
    model = model_setup(args, data_config)
    gpus, dev = set_gpus(args)

    # Optionally warm-start from a checkpoint before training.
    if training_mode and args.load_model_weights:
        model = load_train_model(args, dev)

    # --------------------------------------------------
    # Logger
    # --------------------------------------------------
    wandb_logger = setup_wandb(args)

    # --------------------------------------------------
    # Training
    # --------------------------------------------------
    if training_mode:
        trainer = build_trainer(args, gpus, wandb_logger, training=True)
        args.local_rank = trainer.global_rank

        trainer.fit(
            model=model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )

    # --------------------------------------------------
    # Evaluation
    # --------------------------------------------------
    # NOTE(review): in training mode test_loaders is never built, so
    # supplying --data_test together with training would raise NameError
    # in the loop below — presumably the two modes are mutually exclusive.
    if args.data_test:
        if args.load_model_weights:
            model = load_test_model(args, dev)

        trainer = build_trainer(args, gpus, wandb_logger, training=False)

        # Each entry is a lazy loader factory; evaluation reuses the
        # validation loop on every test dataset in turn.
        for name, get_test_loader in test_loaders.items():
            test_loader = get_test_loader()
            trainer.validate(
                model=model,
                dataloaders=test_loader,
            )
|
src/utils/callbacks.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from lightning.pytorch.callbacks import (
|
| 3 |
+
TQDMProgressBar,
|
| 4 |
+
ModelCheckpoint,
|
| 5 |
+
LearningRateMonitor,
|
| 6 |
+
)
|
| 7 |
+
from src.layers.utils_training import FreezeClustering
|
| 8 |
+
|
| 9 |
+
def get_callbacks(args):
    """Assemble the Lightning callbacks used during training.

    Always includes a progress bar, step-based checkpointing, and an LR
    monitor; optionally freezes the clustering sub-network.
    """
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.model_prefix,  # checkpoints_path, # <--- specify this on the trainer itself for version control
        filename="_{epoch}_{step}",
        # every_n_epochs=val_every_n_epochs,
        # Checkpoint by optimizer step rather than epoch.
        every_n_train_steps=500,
        save_top_k=-1,  # <--- this is important! keep every checkpoint
        save_weights_only=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    callbacks = [
        TQDMProgressBar(refresh_rate=10),
        checkpoint_callback,
        lr_monitor,
    ]
    # When training only the correction heads, freeze the clustering stage.
    if args.freeze_clustering:
        callbacks.append(FreezeClustering())
    return callbacks
|
| 27 |
+
|
| 28 |
+
def get_callbacks_eval(args):
    """Callbacks for evaluation-only runs: just a fast-refreshing progress
    bar (``args`` is accepted for interface parity and unused)."""
    return [TQDMProgressBar(refresh_rate=1)]
|
src/utils/import_tools.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib.util import spec_from_file_location, module_from_spec
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def import_module(path, name='_mod'):
    """Dynamically load a Python module from a file path.

    :param path: filesystem path to a ``.py`` file
    :param name: name to register the module under (not added to sys.modules)
    :return: the executed module object
    :raises ImportError: if no import spec/loader can be created for *path*
        (previously this surfaced as an opaque AttributeError on None)
    """
    spec = spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create import spec for {name!r} from {path!r}")
    mod = module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
|
src/utils/inference/pandas_helpers.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gzip
|
| 2 |
+
import pickle
|
| 3 |
+
import mplhep as hep
|
| 4 |
+
from src.utils.pid_conversion import pid_conversion_dict
|
| 5 |
+
|
| 6 |
+
#hep.style.use("CMS")
|
| 7 |
+
import matplotlib
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def open_mlpf_dataframe(path_mlpf, neutrals_only=False, charged_only=False):
    """Load a pickled shower DataFrame and add derived PID columns.

    Adds ``pid_4_class_true`` by mapping the truth ``pid`` through the
    project's pid_conversion_dict, and replaces predicted-PID sentinel
    values below -1 with NaN.

    NOTE(review): ``neutrals_only`` and ``charged_only`` are currently
    unused — presumably kept for interface compatibility; confirm before
    relying on them for filtering.
    """
    data = pd.read_pickle(path_mlpf)
    sd = data
    sd["pid_4_class_true"] = sd["pid"].map(pid_conversion_dict)
    if "pred_pid_matched" in sd.columns:
        # Values < -1 are unmatched/invalid sentinels, not real PIDs.
        sd.loc[sd["pred_pid_matched"] < -1, "pred_pid_matched"] = np.nan
    return sd
|
| 21 |
+
|
| 22 |
+
def concat_with_batch_fix(dfs, batch_key="number_batch"):
    """Concatenate event DataFrames while keeping batch numbers globally unique.

    Each frame's ``batch_key`` column is shifted by a running offset so that
    batch indices never collide across the concatenated result.

    :param dfs: iterable of DataFrames, each containing ``batch_key``
    :param batch_key: name of the per-event batch-number column
    :return: a single DataFrame with a fresh RangeIndex
    :raises KeyError: if any input frame lacks ``batch_key``
    """
    corrected_dfs = []
    batch_offset = 0

    for df in dfs:
        df = df.copy()
        if batch_key not in df.columns:
            raise KeyError(f"'{batch_key}' not found in one of the DataFrames.")
        df[batch_key] = df[batch_key] + batch_offset
        # Only advance the offset for non-empty frames: max() on an empty
        # column is NaN and would poison every subsequent batch number.
        if len(df):
            batch_offset = df[batch_key].max() + 1
        corrected_dfs.append(df)
    return pd.concat(corrected_dfs, ignore_index=True)
|
| 36 |
+
|
src/utils/load_pretrained_models.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def load_train_model(args, dev):
    """Load a full model checkpoint to warm-start training.

    Uses ``strict=False`` so checkpoints with a partially different layer
    set still load — missing/unexpected keys are silently skipped.
    """
    # Imported lazily to avoid the heavy model import at module load time.
    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
    model = GravnetModel.load_from_checkpoint(
        args.load_model_weights, args=args, dev=0, map_location=dev,strict=False)
    return model
|
| 9 |
+
|
| 10 |
+
def load_test_model(args, dev):
    """Load a model for evaluation, with a special path for correction mode.

    Plain mode: load the checkpoint directly. Correction mode: load the
    correction weights non-strictly, then graft the clustering sub-modules
    (gatr, batch norm, clustering, beta) from a second, clustering-only
    checkpoint.

    NOTE(review): if ``args.load_model_weights`` is None neither branch
    binds ``model`` and the final ``model.eval()`` raises NameError —
    callers apparently guard on load_model_weights; confirm.
    """
    if args.load_model_weights is not None and (not args.correction):
        from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
        model = GravnetModel.load_from_checkpoint(
            args.load_model_weights, args=args, dev=0, map_location=dev, strict=False
        )

    if args.load_model_weights is not None and args.correction:
        from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel
        ckpt = torch.load(args.load_model_weights, map_location=dev)

        # Non-strict load: correction checkpoints may lack clustering keys.
        state_dict = ckpt["state_dict"]
        model = GravnetModel( args=args, dev=0)
        model.load_state_dict(state_dict, strict=False)

        # Graft the frozen clustering stage from its own checkpoint.
        # NOTE(review): map_location is hard-coded to cuda:0 here, ignoring
        # the ``dev`` argument — this breaks CPU-only evaluation; confirm.
        model2 = GravnetModel.load_from_checkpoint(args.load_model_weights_clustering, args=args, dev=0, strict=False, map_location=torch.device("cuda:0"))
        model.gatr = model2.gatr
        model.ScaledGooeyBatchNorm2_1 = model2.ScaledGooeyBatchNorm2_1
        model.clustering = model2.clustering
        model.beta = model2.beta
    model.eval()
    return model
|
| 32 |
+
|
src/utils/logger_wandb.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import wandb
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from sklearn.metrics import roc_curve, roc_auc_score
|
| 5 |
+
import json
|
| 6 |
+
import dgl
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from sklearn.decomposition import PCA
|
| 9 |
+
from torch_scatter import scatter_max
|
| 10 |
+
from matplotlib.cm import ScalarMappable
|
| 11 |
+
from matplotlib.colors import Normalize
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def log_losses_wandb(
    logwandb, num_batches, local_rank, losses, loss, val=False
):
    """Push the condensation-loss components to Weights & Biases.

    Logging happens only when ``logwandb`` is on, on rank 0, and on every
    10th batch (1, 11, 21, ...).  ``val=True`` appends a " val" tag to every
    metric key.  ``losses`` is indexed positionally: 0=lv, 1=beta,
    2=beta sig, 3=beta noise, 12=attractive, 13=repulsive.
    """
    suffix = " val" if val else ""
    should_log = logwandb and local_rank == 0 and ((num_batches - 1) % 10) == 0
    if not should_log:
        return
    wandb.log(
        {
            "loss" + suffix + " regression": loss,
            "loss" + suffix + " lv": losses[0],
            "loss" + suffix + " beta": losses[1],
            "loss" + suffix + " beta sig": losses[2],
            "loss" + suffix + " beta noise": losses[3],
            "loss" + suffix + " attractive": losses[12],
            "loss" + suffix + " repulsive": losses[13],
        }
    )
|
src/utils/parser_args.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Command-line argument definitions shared by training and evaluation.

Importing this module builds a module-level ``parser``; callers run
``parser.parse_args()`` themselves.
"""
import argparse

parser = argparse.ArgumentParser()

# --- Model freezing -------------------------------------------------------
parser.add_argument(
    "--freeze-clustering",
    action="store_true",
    default=False,
    help="Freeze the clustering part of the model",
)


# --- Data inputs ----------------------------------------------------------
parser.add_argument("-c", "--data-config", type=str, help="data config YAML file")

parser.add_argument(
    "-i",
    "--data-train",
    nargs="*",
    default=[],
    help="training files; supported syntax:"
    " (a) plain list, `--data-train /path/to/a/* /path/to/b/*`;"
    " (b) (named) groups [Recommended], `--data-train a:/path/to/a/* b:/path/to/b/*`,"
    " the file splitting (for each dataloader worker) will be performed per group,"
    " and then mixed together, to ensure a uniform mixing from all groups for each worker.",
)
parser.add_argument(
    "-l",
    "--data-val",
    nargs="*",
    default=[],
    help="validation files; when not set, will use training files and split by `--train-val-split`",
)
parser.add_argument(
    "-t",
    "--data-test",
    nargs="*",
    default=[],
    help="testing files; supported syntax:"
    " (a) plain list, `--data-test /path/to/a/* /path/to/b/*`;"
    " (b) keyword-based, `--data-test a:/path/to/a/* b:/path/to/b/*`, will produce output_a, output_b;"
    " (c) split output per N input files, `--data-test a%10:/path/to/a/*`, will split per 10 input files",
)

# --- Data sampling / loading strategy ------------------------------------
parser.add_argument(
    "--data-fraction",
    type=float,
    default=1,
    help="fraction of events to load from each file; for training, the events are randomly selected for each epoch",
)
parser.add_argument(
    "--file-fraction",
    type=float,
    default=1,
    help="fraction of files to load; for training, the files are randomly selected for each epoch",
)
parser.add_argument(
    "--fetch-by-files",
    action="store_true",
    default=False,
    help="When enabled, will load all events from a small number (set by ``--fetch-step``) of files for each data fetching. "
    "Otherwise (default), load a small fraction of events from all files each time, which helps reduce variations in the sample composition.",
)
parser.add_argument(
    "--fetch-step",
    type=float,
    default=0.01,
    help="fraction of events to load each time from every file (when ``--fetch-by-files`` is disabled); "
    "Or: number of files to load each time (when ``--fetch-by-files`` is enabled). Shuffling & sampling is done within these events, so set a large enough value.",
)

parser.add_argument(
    "--train-val-split",
    type=float,
    default=0.8,
    help="training/validation split fraction",
)


# --- Model / checkpoint paths --------------------------------------------
parser.add_argument(
    "-n",
    "--network-config",
    type=str,
    help="network architecture configuration file; the path must be relative to the current dir",
)
parser.add_argument(
    "-m",
    "--model-prefix",
    type=str,
    default="models/{auto}/networkss",
    help="path to save or load the model; for training, this will be used as a prefix, so model snapshots "
    "will saved to `{model_prefix}_epoch-%d_state.pt` after each epoch, and the one with the best "
    "validation metric to `{model_prefix}_best_epoch_state.pt`; for testing, this should be the full path "
    "including the suffix, otherwise the one with the best validation metric will be used; "
    "for training, `{auto}` can be used as part of the path to auto-generate a name, "
    "based on the timestamp and network configuration",
)

parser.add_argument(
    "--load-model-weights",
    type=str,
    default=None,
    help="initialize model with pre-trained weights",
)
parser.add_argument(
    "--load-model-weights-clustering",
    type=str,
    default=None,
    help="initialize model with pre-trained weights for clustering part of the model",
)

# --- Training schedule ----------------------------------------------------
parser.add_argument("--start-lr", type=float, default=5e-3, help="start learning rate")

parser.add_argument("--num-epochs", type=int, default=20, help="number of epochs")
parser.add_argument(
    "--steps-per-epoch",
    type=int,
    default=None,
    help="number of steps (iterations) per epochs; "
    "if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--steps-per-epoch-val",
    type=int,
    default=None,
    help="number of steps (iterations) per epochs for validation; "
    "if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--samples-per-epoch",
    type=int,
    default=None,
    help="number of samples per epochs; "
    "if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples",
)
parser.add_argument(
    "--samples-per-epoch-val",
    type=int,
    default=None,
    help="number of samples per epochs for validation; "
    "if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples",
)
parser.add_argument("--batch-size", type=int, default=128, help="batch size")

# --- Hardware / dataloading ----------------------------------------------
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help='device for the training/testing; to use CPU, set to empty string (""); to use multiple gpu, set it as a comma separated list, e.g., `1,2,3,4`',
)

parser.add_argument(
    "--num-workers",
    type=int,
    default=1,
    help="number of threads to load the dataset; memory consumption and disk access load increases (~linearly) with this numbers",
)
parser.add_argument(
    "--prefetch-factor",
    type=int,
    default=1,
    help="How many items to prefetch in the dataloaders. Should be about the same order of magnitude as batch size for optimal performance.",
)
parser.add_argument(
    "--predict",
    action="store_true",
    default=False,
    help="run prediction instead of training",
)




# --- Logging (Weights & Biases) ------------------------------------------
parser.add_argument(
    "--log-wandb", action="store_true", default=False, help="use wandb for loging"
)
parser.add_argument(
    "--wandb-displayname",
    type=str,
    help="give display name to wandb run, if not entered a random one is generated",
)
parser.add_argument(
    "--wandb-projectname", type=str, help="project where the run is stored inside wandb"
)
parser.add_argument(
    "--wandb-entity", type=str, help="username or team name where you are sending runs"
)


# --- Object-condensation hyperparameters ---------------------------------
parser.add_argument(
    "--qmin", type=float, default=0.1, help="define qmin for condensation"
)


parser.add_argument(
    "--frac_cluster_loss",
    type=float,
    default=0,
    help="Fraction of total pairs to use for the clustering loss",
)




parser.add_argument(
    "--use-average-cc-pos",
    default=0.0,
    type=float,
    help="push the alpha to the mean of the coordinates in the object by this value",
)


# --- Mode switches --------------------------------------------------------
parser.add_argument(
    "--correction",
    action="store_true",
    default=False,
    help="Train correction only",
)




parser.add_argument(
    "--use-gt-clusters",
    default=False,
    action="store_true",
    help="If toggled, uses ground-truth clusters instead of the predicted ones by the model. We can use this to simulate 'ideal' clustering.",
)


# --- Evaluation output ----------------------------------------------------
parser.add_argument(
    "--name-output",
    type=str,
    help="name of the dataframe stored during eval",
)
parser.add_argument(
    "--train-batches",
    default=100,
    type=int,
    help="number of train batches",
)
parser.add_argument(
    "--pandora",
    default=False,
    action="store_true",
    help="using pandora information",
)
|
src/utils/pid_conversion.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A global variable, so it doesn't have to be modified in 10 different places when new particles are added

# Maps PDG particle codes to the model's coarse particle classes 0..4.
# From the groupings below the classes appear to be (confirm against the
# training labels): 0 = electrons, 1 = charged hadrons, 2 = neutral hadrons,
# 3 = photons, 4 = muons.
# NOTE(review): keys such as 1000020030.0 are float-encoded nuclear PDG codes.
pid_conversion_dict = {11: 0, -11: 0, 211: 1, -211: 1, 130: 2, -130: 2, 2112: 2, -2112: 2, 22: 3, 321: 1, -321: 1, 2212: 1, -2212: 1, 310: 2, -310: 2, 3122: 2, -3122: 2, 3212: 2, -3212: 2, 3112: 1, -3112: 1, 3222: 1, -3222: 1, 3224: 1, -3224: 1, 3312: 2, -3312: 2, 13: 4, -13: 4, 3322: 2, -3322: 2, 1000020030.0: 2, 1000010050.0: 2, 1000010048.0: 2, 3334: 1, -3334:1, 1000020032.0: 2, 1000080128.0: 2, 1000110208.0: 2, 1000040064.0: 2, 1000070144.0: 2, 1000010020.0:2, 1000010030.0:2, 1000020040.0:2}

# Subset of PDG codes Pandora reconstructs, mapped to the same class labels.
pandora_to_our_mapping = {211: 1, -211: 1, -13: 4, 13: 4, 11: 0, -11: 0, 22: 3, 2112: 2, 130: 2, -2112: 2}
# Inverse direction: each class label to the list of PDG codes it covers.
our_to_pandora_mapping = {0: [11, -11], 1: [211, -211,2212, -2212, 321, -321, 3222, 3112, 3224, -3112, -3224], 2: [2112, 130, 310, 3122, 3212], 3: [22], 4:[13,-13]}
|
| 7 |
+
|
src/utils/post_clustering_features.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch_scatter import scatter_sum, scatter_std
|
| 3 |
+
|
| 4 |
+
def calculate_phi(x, y, z=None):
    """Azimuthal angle phi = atan2(y, x); ``z`` is accepted but unused."""
    return torch.atan2(y, x)
|
| 7 |
+
def calculate_eta(x, y, z):
    """Pseudorapidity eta = -ln(tan(theta/2)), with theta = atan2(r_xy, z)."""
    r_xy = torch.sqrt(x * x + y * y)
    polar_angle = torch.arctan2(r_xy, z)
    return -torch.log(torch.tan(polar_angle / 2))
|
| 11 |
+
def get_post_clustering_features(graphs_new, sum_e):
    '''
    Obtain graph-level qualitative features that can then be used to regress the energy corr. factor.

    Per-hit feature columns read from ``graphs_new.ndata["h"]`` (assumed layout —
    confirm against the graph-construction code): 5 = ECAL flag, 6 = HCAL flag,
    7 = muon-detector flag, 8 = hit energy, 9 = track momentum.

    :param graphs_new: batched DGL graph of clustered, matched showers
    :param sum_e: per-graph total hit energy, used to normalize the calorimeter sums
    :return: tensor of shape (num_graphs, 11) with columns
        [ecal_e/sum_e, hcal_e/sum_e, n_hits, mean track p, ecal dispersion,
         hcal dispersion, sum_e, n_tracks, clamped track chi2, muon e, muon n_hits];
        nan/inf entries (e.g. from division by zero tracks) are zeroed by
        ``torch.nan_to_num`` at the end.
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    # Build a per-hit graph index (0..num_graphs-1) for the scatter reductions.
    batch_idx = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
    e_hits = graphs_new.ndata["h"][:, 8]

    # Muon-detector hits: total deposited energy and hit count per graph.
    muon_hits = graphs_new.ndata["h"][:, 7]
    filter_muon = torch.where(muon_hits)[0]
    per_graph_e_hits_muon = scatter_sum(e_hits[filter_muon], batch_idx[filter_muon], dim_size=batch_idx.max() + 1)
    per_graph_n_hits_muon = scatter_sum((e_hits[filter_muon] > 0).type(torch.int), batch_idx[filter_muon], dim_size=batch_idx.max() + 1)
    ecal_hits = graphs_new.ndata["h"][:, 5]
    filter_ecal = torch.where(ecal_hits)[0]
    hcal_hits = graphs_new.ndata["h"][:, 6]
    filter_hcal = torch.where(hcal_hits)[0]
    per_graph_e_hits_ecal = scatter_sum(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=batch_idx.max() + 1)
    # similar as above but with scatter_std (variance of ECAL hit energies)
    per_graph_e_hits_ecal_dispersion = scatter_std(e_hits[filter_ecal], batch_idx[filter_ecal], dim_size=batch_idx.max() + 1) ** 2
    per_graph_e_hits_hcal = scatter_sum(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=batch_idx.max() + 1)
    # similar as above but with scatter_std -- !!!!! TODO: Retrain the base EC models using this definition !!!!!
    per_graph_e_hits_hcal_dispersion = scatter_std(e_hits[filter_hcal], batch_idx[filter_hcal], dim_size=batch_idx.max() + 1) ** 2
    # track_nodes =
    # Track summaries: mean momentum and mean (clamped) chi^2 over the tracks
    # of each graph; graphs without tracks get track_p forced to 0 below.
    track_p = scatter_sum(graphs_new.ndata["h"][:, 9], batch_idx)
    chis_tracks = scatter_sum(graphs_new.ndata["chi_squared_tracks"], batch_idx)
    num_tracks = scatter_sum((graphs_new.ndata["h"][:, 9] > 0).type(torch.int), batch_idx)
    track_p = track_p / num_tracks
    track_p[num_tracks == 0] = 0.
    # NOTE(review): chis_tracks / 0 yields nan/inf for trackless graphs — it is
    # only cleaned up by the final nan_to_num.
    chis_tracks = chis_tracks / num_tracks
    num_hits = graphs_new.batch_num_nodes()
    # print shapes of the below things

    return torch.nan_to_num(
        torch.stack([per_graph_e_hits_ecal / sum_e,
                     per_graph_e_hits_hcal / sum_e,
                     num_hits, track_p,
                     per_graph_e_hits_ecal_dispersion,
                     per_graph_e_hits_hcal_dispersion,
                     sum_e, num_tracks, torch.clamp(chis_tracks, -5, 5),
                     per_graph_e_hits_muon,
                     per_graph_n_hits_muon
                     ]).T
    )
+
|
| 62 |
+
def get_extra_features(graphs_new, betas):
    '''
    Obtain extra graph-level features for debugging of the fakes.

    :param graphs_new: batched DGL graph (one sub-graph per candidate shower)
    :param betas: per-node condensation beta values, aligned with the nodes
        of ``graphs_new``
    :return: tensor of shape (num_graphs, 1 + n_highest_betas) holding
        [num_nodes, top-k betas] per graph
    '''
    batch_num_nodes = graphs_new.batch_num_nodes()  # Num. of hits in each graph
    batch_idx = []
    topk_highest_betas = []
    for i, n in enumerate(batch_num_nodes):
        batch_idx.extend([i] * n)
    batch_idx = torch.tensor(batch_idx).to(graphs_new.device)
    n_highest_betas = 1
    for i in range(len(batch_num_nodes)):
        betas_i = betas[batch_idx == i]
        # torch.topk raises when k > numel, so clamp k and zero-pad instead.
        # (The previous padding branch was unreachable and appended the bound
        # method `.values` of a plain tensor, which would crash torch.stack.)
        k = min(n_highest_betas, betas_i.numel())
        top_vals = torch.topk(betas_i, k).values
        if k < n_highest_betas:
            top_vals = torch.cat(
                [top_vals, torch.zeros(n_highest_betas - k, device=top_vals.device)]
            )
        topk_highest_betas.append(top_vals)
    topk_highest_betas = torch.stack(topk_highest_betas)
    # Concat with batch_num_nodes
    features = torch.cat([batch_num_nodes.view(-1, 1), topk_highest_betas], dim=1)
    return features
|
src/utils/train_utils.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import ast
|
| 3 |
+
import sys
|
| 4 |
+
import shutil
|
| 5 |
+
import glob
|
| 6 |
+
import functools
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from src.dataset.dataset import SimpleIterDataset
|
| 11 |
+
from src.utils.import_tools import import_module
|
| 12 |
+
from src.dataset.functions_graph import graph_batch_func
|
| 13 |
+
|
| 14 |
+
def set_gpus(args):
    """Parse the ``--gpus`` string ("0,1,...") into GPU ids and a device.

    :param args: parsed CLI arguments; only ``args.gpus`` is read.
    :return: ``(gpus, dev)`` — the list of GPU ids and a ``torch.device``
        pointing at the first one.
    :raises Exception: when ``args.gpus`` is empty — the original code
        assigned defaults and then unconditionally raised, so the assignments
        were dead; the refusal to default is kept, the dead code is removed.
    """
    if not args.gpus:
        print("No GPUs flag provided - Setting GPUs to [0]")
        raise Exception("Please provide GPU number")
    gpus = [int(i) for i in args.gpus.split(",")]
    dev = torch.device(gpus[0])
    print("Using GPUs:", gpus)
    return gpus, dev
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_gpu_dev(args):
    """Translate the --gpus string into Lightning (accelerator, devices).

    An empty ``args.gpus`` means CPU and yields ``(0, 0)``; otherwise the
    accelerator is "gpu" and the raw gpu string is passed through.
    """
    if args.gpus == "":
        return 0, 0
    return "gpu", args.gpus
| 36 |
+
# TODO change this to use it from config file

def model_setup(args, data_config):
    """
    Loads the model.

    Imports the network module named by ``args.network_config`` and builds the
    model on the first GPU listed in ``args.gpus`` (CPU when the flag is empty).

    :param args: parsed CLI arguments (reads ``network_config`` and ``gpus``)
    :param data_config: data configuration forwarded to ``get_model``
    :return: the inner ``model.mod`` of the wrapped model
        (NOTE(review): an older docstring claimed ``model, model_info,
        network_module`` — only ``model.mod`` is returned)
    """
    network_module = import_module(args.network_config, name="_network_module")

    if args.gpus:
        gpus = [int(i) for i in args.gpus.split(",")]  # ?
        dev = torch.device(gpus[0])
        print("using GPUs:", gpus)
    else:
        gpus = None
        local_rank = 0  # NOTE(review): assigned but unused in this branch
        dev = torch.device("cpu")
    model, model_info = network_module.get_model(
        data_config, args=args, dev=dev
    )
    return model.mod
|
| 60 |
+
|
| 61 |
+
def get_samples_steps_per_epoch(args):
    """Derive steps-per-epoch settings from samples-per-epoch, in place.

    For train and val independently: exactly one of ``--steps-per-epoch*`` /
    ``--samples-per-epoch*`` may be given (both set raises RuntimeError).
    When only train steps are known, val steps default to a value scaled by
    the train/val split; a negative val step count means "use everything".

    :return: the mutated ``args`` namespace.
    """
    if args.samples_per_epoch is not None:
        if args.steps_per_epoch is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch` or `--samples-per-epoch`, but not both!"
            )
        args.steps_per_epoch = args.samples_per_epoch // args.batch_size

    if args.samples_per_epoch_val is not None:
        if args.steps_per_epoch_val is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch-val` or `--samples-per-epoch-val`, but not both!"
            )
        args.steps_per_epoch_val = args.samples_per_epoch_val // args.batch_size

    if args.steps_per_epoch_val is None and args.steps_per_epoch is not None:
        split = args.train_val_split
        args.steps_per_epoch_val = round(args.steps_per_epoch * (1 - split) / split)

    if args.steps_per_epoch_val is not None and args.steps_per_epoch_val < 0:
        args.steps_per_epoch_val = None
    return args
|
| 84 |
+
def to_filelist(args, mode="train"):
    """Expand --data-train / --data-val entries into grouped file lists.

    Each entry may be "name:glob" (named group) or a bare glob, which lands
    in the anonymous group "_".  Files are globbed, grouped, and sorted.
    Under distributed training (``args.local_rank`` set, train mode) each
    rank keeps its stride-slice of every group and shuffles it.

    :return: ``(file_dict, filelist)`` — the per-group mapping and the
        flattened list of all files (asserted duplicate-free).
    """
    if mode == "train":
        flist = args.data_train
    elif mode == "val":
        flist = args.data_val
    else:
        raise NotImplementedError("Invalid mode %s" % mode)

    # keyword-based: 'a:/path/to/a b:/path/to/b'
    file_dict = {}
    for entry in flist:
        if ":" in entry:
            name, pattern = entry.split(":")
        else:
            name, pattern = "_", entry
        file_dict.setdefault(name, []).extend(glob.glob(pattern))

    # sort files
    file_dict = {name: sorted(files) for name, files in file_dict.items()}

    if args.local_rank is not None:
        if mode == "train":
            gpus_list, _ = set_gpus(args)
            local_world_size = len(gpus_list)  # int(os.environ['LOCAL_WORLD_SIZE'])
            sharded = {}
            for name, files in file_dict.items():
                rank_files = files[args.local_rank :: local_world_size]
                assert len(rank_files) > 0
                np.random.shuffle(rank_files)
                sharded[name] = rank_files
            file_dict = sharded
            print(args.local_rank, len(file_dict["_"]))

    filelist = sum(file_dict.values(), [])
    assert len(filelist) == len(set(filelist))
    return file_dict, filelist
|
| 127 |
+
|
| 128 |
+
def train_load(args):
    """
    Loads the training data.

    Builds two ``SimpleIterDataset``s and wraps them in ``DataLoader``s with
    the project's graph collate function.  When ``--data-val`` is absent the
    training files are reused and split by ``--train-val-split``.

    :param args: parsed CLI arguments
    :return: train_loader, val_loader, data_config, train_inputs
        (NOTE(review): the last two are currently hard-coded to 0 — the
        config plumbing is commented out below)
    """
    train_file_dict, train_files = to_filelist(args, "train")
    if args.data_val:
        # Explicit validation files: both datasets span their full range.
        val_file_dict, val_files = to_filelist(args, "val")
        train_range = val_range = (0, 1)
    else:
        # No separate val files: split the training files by fraction.
        val_file_dict, val_files = train_file_dict, train_files
        train_range = (0, args.train_val_split)
        val_range = (args.train_val_split, 1)

    train_data = SimpleIterDataset(
        train_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=None,
        remake_weights=False,
        load_range_and_fraction=(train_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        # Infinite iteration when a fixed number of steps per epoch is set.
        infinity_mode=args.steps_per_epoch is not None,
        name="train" + ("" if args.local_rank is None else "_rank%d" % args.local_rank),
        args_parse=args
    )
    val_data = SimpleIterDataset(
        val_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=None,
        load_range_and_fraction=(val_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=args.steps_per_epoch_val is not None,
        name="val" + ("" if args.local_rank is None else "_rank%d" % args.local_rank),
        args_parse=args
    )

    collator_func = graph_batch_func
    # train_data_arg = train_data
    # val_data_arg = val_data
    # if args.train_cap == 1:
    #     train_data_arg = [next(iter(train_data_arg))]
    # if args.val_cap == 1:
    #     val_data_arg = [next(iter(val_data_arg))]
    # prefetch_factor must stay None when num_workers == 0 (DataLoader requirement).
    prefetch_factor = None
    if args.num_workers > 0:
        prefetch_factor = args.prefetch_factor
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        # Never more workers than (fractional) files available.
        num_workers=min(args.num_workers, int(len(train_files) * args.file_fraction)),
        collate_fn=collator_func,
        persistent_workers=False,
        prefetch_factor=prefetch_factor
    )
    val_loader = DataLoader(
        val_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        collate_fn=collator_func,
        num_workers=min(args.num_workers, int(len(val_files) * args.file_fraction)),
        persistent_workers=args.num_workers > 0
        and args.steps_per_epoch_val is not None,
        prefetch_factor=prefetch_factor
    )

    data_config = 0 #train_data.config
    train_input_names = 0 #train_data.config.input_names
    train_label_names = 0 # train_data.config.label_names

    return train_loader, val_loader, data_config, train_input_names
|
| 211 |
+
|
| 212 |
+
def test_load(args):
    """
    Loads the test data.

    Parses ``--data-test`` entries (optionally "name:glob", optionally
    "name%N:glob" to split output every N files) and returns lazy loader
    factories so each test group is only materialized when requested.

    :param args: parsed CLI arguments
    :return: test_loaders, data_config — a dict mapping group name to a
        zero-argument callable producing its DataLoader, and a placeholder
        data_config (currently hard-coded to 0; see commented line below)
    """
    # keyword-based --data-test: 'a:/path/to/a b:/path/to/b'
    # split --data-test: 'a%10:/path/to/a/*'
    file_dict = {}
    split_dict = {}
    for f in args.data_test:
        if ":" in f:
            name, fp = f.split(":")
            if "%" in name:
                name, split = name.split("%")
                split_dict[name] = int(split)
        else:
            # NOTE(review): anonymous test group is "" here, while
            # to_filelist uses "_" for the same purpose.
            name, fp = "", f
        files = glob.glob(fp)
        if name in file_dict:
            file_dict[name] += files
        else:
            file_dict[name] = files

    # sort files
    for name, files in file_dict.items():
        file_dict[name] = sorted(files)

    # apply splitting
    for name, split in split_dict.items():
        files = file_dict.pop(name)
        for i in range((len(files) + split - 1) // split):
            file_dict[f"{name}_{i}"] = files[i * split : (i + 1) * split]

    def get_test_loader(name):
        # Build the DataLoader for one named test group on demand.
        filelist = file_dict[name]
        num_workers = min(args.num_workers, len(filelist))
        test_data = SimpleIterDataset(
            {name: filelist},
            args.data_config,
            for_training=False,
            extra_selection=None,
            load_range_and_fraction=((0, 1), args.data_fraction),
            fetch_by_files=True,
            fetch_step=1,
            name="test_" + name,
            args_parse=args
        )
        test_loader = DataLoader(
            test_data,
            num_workers=num_workers,
            batch_size=args.batch_size,
            drop_last=False,
            pin_memory=True,
            collate_fn=graph_batch_func,
        )
        return test_loader

    # functools.partial defers construction until the loader is needed.
    test_loaders = {
        name: functools.partial(get_test_loader, name) for name in file_dict
    }
    #data_config = SimpleIterDataset({}, args.data_config, for_training=False).config
    data_config = 0
    return test_loaders, data_config
|
| 277 |
+
|
| 278 |
+
def count_parameters(model):
    """Number of trainable parameters in the wrapped module ``model.mod``."""
    trainable = (p.numel() for p in model.mod.parameters() if p.requires_grad)
    return sum(trainable)
|
tests/test_cpu_attention.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the CPU-compatible attention patch in src/inference.py."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _cpu_sdpa_under_test(q, k, v, attn_mask=None):
|
| 8 |
+
"""Standalone copy of _cpu_sdpa (the patched attention) for testing.
|
| 9 |
+
|
| 10 |
+
Mirrors the implementation in src.inference._patch_gatr_attention_for_cpu.
|
| 11 |
+
"""
|
| 12 |
+
B, H, N, D = q.shape
|
| 13 |
+
scale = float(D) ** -0.5
|
| 14 |
+
|
| 15 |
+
q2 = q.reshape(B * H, N, D)
|
| 16 |
+
k2 = k.reshape(B * H, N, D)
|
| 17 |
+
v2 = v.reshape(B * H, N, D)
|
| 18 |
+
|
| 19 |
+
attn = torch.bmm(q2 * scale, k2.transpose(1, 2))
|
| 20 |
+
|
| 21 |
+
if attn_mask is not None:
|
| 22 |
+
attn = attn.masked_fill(~attn_mask.unsqueeze(0), float("-inf"))
|
| 23 |
+
|
| 24 |
+
attn = torch.softmax(attn, dim=-1)
|
| 25 |
+
attn = attn.nan_to_num(0.0)
|
| 26 |
+
|
| 27 |
+
out = torch.bmm(attn, v2)
|
| 28 |
+
return out.reshape(B, H, N, D)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_cpu_sdpa_matches_reference():
    """The CPU SDPA must agree with PyTorch's reference implementation."""
    torch.manual_seed(42)
    shape = (2, 4, 16, 32)
    q, k, v = (torch.randn(*shape) for _ in range(3))

    out_ours = _cpu_sdpa_under_test(q, k, v)
    # PyTorch reference (no mask)
    out_ref = F.scaled_dot_product_attention(q, k, v)

    assert out_ours.shape == shape
    assert torch.allclose(out_ours, out_ref, atol=1e-5), (
        f"Max diff: {(out_ours - out_ref).abs().max().item()}"
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_cpu_sdpa_output_shape():
    """Output shape must be [B, H, N, D], matching the input convention."""
    dims = (1, 8, 64, 16)
    q, k, v = (torch.randn(*dims) for _ in range(3))
    result = _cpu_sdpa_under_test(q, k, v)
    assert result.shape == dims
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_cpu_sdpa_single_head():
    """Single-head attention must work correctly."""
    torch.manual_seed(0)
    dims = (1, 1, 10, 8)
    q, k, v = (torch.randn(*dims) for _ in range(3))

    ours = _cpu_sdpa_under_test(q, k, v)
    reference = F.scaled_dot_product_attention(q, k, v)

    assert torch.allclose(ours, reference, atol=1e-5)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_cpu_sdpa_asymmetric_heads_items():
    """Ensure heads and items dimensions are not confused.

    When H != N, swapping them would change the tensor layout and
    produce different (wrong) results.
    """
    torch.manual_seed(123)
    dims = (1, 3, 7, 16)  # heads != items on purpose
    q, k, v = (torch.randn(*dims) for _ in range(3))

    ours = _cpu_sdpa_under_test(q, k, v)
    reference = F.scaled_dot_product_attention(q, k, v)

    assert ours.shape == dims
    assert torch.allclose(ours, reference, atol=1e-5), (
        f"Max diff: {(ours - reference).abs().max().item()}"
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
    # Run every test case in order when executed as a script.
    for _case in (
        test_cpu_sdpa_matches_reference,
        test_cpu_sdpa_output_shape,
        test_cpu_sdpa_single_head,
        test_cpu_sdpa_asymmetric_heads_items,
    ):
        _case()
    print("All tests passed.")
|
tests/test_csv_priority.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests that CSV data takes priority over parquet when both are available.
|
| 2 |
+
|
| 3 |
+
This validates the fix for the issue where loading an event from parquet and
|
| 4 |
+
then modifying the CSV text fields (e.g. removing tracks) was ignored because
|
| 5 |
+
the code always re-loaded from the parquet file.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import ast
|
| 10 |
+
import textwrap
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _extract_source_priority_logic():
    """Return the full text of app.py for source-level inspection.

    The caller uses this text to check that the ``use_csv`` branch is
    tested *before* ``use_parquet``, so that user edits to the CSV text
    fields are respected even when a parquet file path is present.
    """
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as fh:
        return fh.read()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_csv_checked_before_parquet():
    """In run_inference_ui, the ``if use_csv`` branch must come before
    ``use_parquet`` so that CSV edits are not silently ignored."""
    source = _extract_source_priority_logic()

    # Find positions of the key branching statements
    idx_csv = source.find("if use_csv:")
    idx_parquet_elif = source.find("elif use_parquet:")

    # BUG FIX: a plain source.find("if use_parquet:") also matches *inside*
    # "elif use_parquet:" (substring), so it could never distinguish a
    # standalone-if (the buggy pattern) from the elif fallback. Scan all
    # occurrences and keep only one that is NOT preceded by "el".
    idx_parquet_if = -1
    search_from = 0
    while True:
        pos = source.find("if use_parquet:", search_from)
        if pos == -1:
            break
        preceded_by_el = pos >= 2 and source.startswith("elif use_parquet:", pos - 2)
        if not preceded_by_el:
            idx_parquet_if = pos
            break
        search_from = pos + 1

    # "if use_csv:" must exist
    assert idx_csv != -1, "Could not find 'if use_csv:' in app.py"

    # "elif use_parquet:" must exist (parquet is the fallback)
    assert idx_parquet_elif != -1, (
        "Could not find 'elif use_parquet:' in app.py — parquet should be "
        "a fallback after CSV"
    )

    # CSV check must come before the parquet fallback
    assert idx_csv < idx_parquet_elif, (
        "'if use_csv:' must appear before 'elif use_parquet:' so that "
        "user CSV edits take priority over re-reading the parquet file"
    )

    # There should NOT be a standalone "if use_parquet:" that takes
    # priority over CSV (the old buggy pattern): a dispatch to
    # load_event_from_parquet before "if use_csv:" appears.
    if idx_parquet_if != -1:
        assert idx_parquet_if > idx_csv or "load_event_from_parquet" not in source[idx_parquet_if:idx_csv], (
            "Found 'if use_parquet:' with load_event_from_parquet before "
            "'if use_csv:' — this is the bug where parquet takes priority "
            "over CSV edits"
        )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_parse_csv_event_logic():
    """_parse_csv_event should correctly build event dicts from CSV text.

    We inline the same parsing logic used by app.py to avoid importing
    the module (which requires heavy dependencies like gradio).
    """
    import io
    import numpy as np
    import pandas as pd

    def _read(text, min_cols=1):
        # Empty/blank text -> empty (0, min_cols) float array.
        if not text or not text.strip():
            return np.zeros((0, min_cols), dtype=np.float64)
        frame = pd.read_csv(io.StringIO(text), header=None)
        return frame.values.astype(np.float64)

    def _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora=""):
        hits = _read(csv_hits, 11)
        tracks = _read(csv_tracks, 25)
        # Pad track rows out to 25 columns when the CSV is narrower.
        if tracks.shape[0] > 0 and tracks.shape[1] < 25:
            padding = np.zeros((tracks.shape[0], 25 - tracks.shape[1]))
            tracks = np.concatenate([tracks, padding], axis=1)
        return {
            "X_hit": hits,
            "X_track": tracks,
            "X_gen": _read(csv_particles, 18),
            "X_pandora": _read(csv_pandora, 9),
            "ygen_hit": np.full(len(hits), -1, dtype=np.int64),
            "ygen_track": np.full(len(tracks), -1, dtype=np.int64),
        }

    # Basic parse
    csv_hits = "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1"
    event = _parse_csv_event(csv_hits, "", "", "")
    assert event["X_hit"].shape == (1, 11)
    assert event["X_track"].shape == (0, 25)
    assert np.isclose(event["X_hit"][0, 5], 1.23)

    # Empty tracks after removing them
    assert _parse_csv_event(csv_hits, "", "", "")["X_track"].shape[0] == 0

    # Two tracks vs one track
    first_row = "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
    second_row = "1,0,0,0,0,3.0,1.0,1.5,2.1,0,0,0,1700.0,100.0,80.0,10.0,6,0,0,0,0,0,0.9,1.4,2.0"
    event_two = _parse_csv_event(csv_hits, first_row + "\n" + second_row, "", "")
    event_one = _parse_csv_event(csv_hits, first_row, "", "")
    assert event_two["X_track"].shape[0] == 2
    assert event_one["X_track"].shape[0] == 1
| 126 |
+
|
| 127 |
+
def test_input_source_decision_logic():
    """Simulate the decision logic from run_inference_ui and verify that
    CSV is used even when a parquet path is present."""

    def decide_source(parquet_path, csv_hits):
        """Mirrors the decision logic in run_inference_ui."""
        has_csv = bool(csv_hits and csv_hits.strip())
        has_parquet = parquet_path and os.path.isfile(parquet_path)
        if has_csv:
            return "csv"
        if has_parquet:
            return "parquet"
        return "none"

    # CSV present + parquet path present → should use CSV
    # (use this script as a stand-in for an existing file)
    existing_file = os.path.abspath(__file__)
    assert decide_source(existing_file, "some,csv,data") == "csv"

    # CSV present + no parquet → should use CSV
    assert decide_source("", "some,csv,data") == "csv"

    # CSV empty + parquet present → should use parquet
    assert decide_source(existing_file, "") == "parquet"

    # Both empty → none
    assert decide_source("", "") == "none"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
    # Run every test case in order when executed as a script.
    for _case in (
        test_csv_checked_before_parquet,
        test_parse_csv_event_logic,
        test_input_source_decision_logic,
    ):
        _case()
    print("All tests passed.")
|
tests/test_energy_correction_no_matches.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests that energy correction runs even when no MC-truth showers are matched.
|
| 2 |
+
|
| 3 |
+
The bug: ``_run_energy_correction`` returned early (``if not graphs_matched:
|
| 4 |
+
return particles_df``) whenever no predicted cluster could be matched to a
|
| 5 |
+
true particle. In pure inference mode (no MC truth) *all* clusters are
|
| 6 |
+
"fakes" and ``graphs_matched`` is always empty, so the correction was never
|
| 7 |
+
applied and the output table only contained the basic ``energy_sum_hits`` /
|
| 8 |
+
``p_track`` columns.
|
| 9 |
+
|
| 10 |
+
The fix: only bail out when *both* ``graphs_matched`` **and** ``graphs_fakes``
|
| 11 |
+
are empty (i.e. there are literally no clusters to correct).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import ast
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_function_source(path, func_name):
|
| 19 |
+
"""Return the source of a top-level function from *path*."""
|
| 20 |
+
with open(path) as f:
|
| 21 |
+
source = f.read()
|
| 22 |
+
tree = ast.parse(source)
|
| 23 |
+
lines = source.splitlines(keepends=True)
|
| 24 |
+
for node in tree.body:
|
| 25 |
+
if isinstance(node, ast.FunctionDef) and node.name == func_name:
|
| 26 |
+
return "".join(lines[node.lineno - 1 : node.end_lineno])
|
| 27 |
+
raise ValueError(f"{func_name} not found in {path}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Path to the inference module whose source the tests below inspect.
_TESTS_DIR = os.path.dirname(__file__)
INFERENCE_PATH = os.path.join(_TESTS_DIR, "..", "src", "inference.py")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_early_return_requires_both_empty():
    """The early return must check both graphs_matched *and* graphs_fakes.

    The old (buggy) guard was:
        if not graphs_matched:
            return particles_df

    The fixed guard must be:
        if not graphs_matched and not graphs_fakes:
            return particles_df
    """
    src = _get_function_source(INFERENCE_PATH, "_run_energy_correction")

    buggy_guard = "if not graphs_matched:\n return particles_df"
    fixed_guard = "if not graphs_matched and not graphs_fakes:"

    # The buggy single-condition early return must NOT appear
    assert buggy_guard not in src, (
        "Found the old single-condition early return 'if not graphs_matched'; "
        "energy correction would be skipped whenever no MC-truth matches exist."
    )

    # The correct two-condition guard must be present
    assert fixed_guard in src, (
        "Expected 'if not graphs_matched and not graphs_fakes:' in "
        "_run_energy_correction but did not find it."
    )
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_true_energies_t_not_called_with_cat_on_empty():
    """``torch.cat(true_energies, dim=0)`` must not appear unconditionally.

    When ``graphs_matched`` is empty, ``true_energies`` is an empty list and
    ``torch.cat([], dim=0)`` raises a RuntimeError. The fixed code removes
    this line entirely (the variable was unused anyway).
    """
    src = _get_function_source(INFERENCE_PATH, "_run_energy_correction")

    needle = "true_energies_t = torch.cat(true_energies"
    # Either the assignment is gone, or it is guarded
    if needle in src:
        src_lines = src.splitlines()
        for lineno, text in enumerate(src_lines):
            if needle not in text:
                continue
            # A guard must appear within the five preceding lines
            window = src_lines[max(0, lineno - 5):lineno]
            guarded = any(
                "if true_energies" in prev or "if graphs_matched" in prev
                for prev in window
            )
            assert guarded, (
                f"Line {lineno}: unguarded 'torch.cat(true_energies)' would "
                "raise RuntimeError on empty list when no showers match."
            )
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
    # Run every test case in order when executed as a script.
    for _case in (
        test_early_return_requires_both_empty,
        test_true_energies_t_not_called_with_cat_on_empty,
    ):
        _case()
    print("All tests passed.")
|
tests/test_pfo_links.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the hit → Pandora cluster mapping (PFO links) field.
|
| 2 |
+
|
| 3 |
+
Validates that:
|
| 4 |
+
1. _parse_csv_event correctly parses the csv_pfo_links parameter.
|
| 5 |
+
2. PFO links are gracefully handled when CSV is modified (partial matches).
|
| 6 |
+
3. The _load_event_into_csv function includes PFO links output.
|
| 7 |
+
4. The run_inference_ui function accepts the csv_pfo_links parameter.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import io
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Inline the parsing logic to avoid importing app.py (heavy dependencies)
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
def _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora="", csv_pfo_links=""):
|
| 21 |
+
"""Mirror of the _parse_csv_event logic from app.py."""
|
| 22 |
+
|
| 23 |
+
def _read(text, min_cols=1):
|
| 24 |
+
if not text or not text.strip():
|
| 25 |
+
return np.zeros((0, min_cols), dtype=np.float64)
|
| 26 |
+
df = pd.read_csv(io.StringIO(text), header=None)
|
| 27 |
+
return df.values.astype(np.float64)
|
| 28 |
+
|
| 29 |
+
hits_arr = _read(csv_hits, 11)
|
| 30 |
+
tracks_arr = _read(csv_tracks, 25)
|
| 31 |
+
particles_arr = _read(csv_particles, 18)
|
| 32 |
+
pandora_arr = _read(csv_pandora, 9)
|
| 33 |
+
if tracks_arr.shape[1] < 25 and tracks_arr.shape[0] > 0:
|
| 34 |
+
pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
|
| 35 |
+
tracks_arr = np.concatenate([tracks_arr, pad], axis=1)
|
| 36 |
+
ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
|
| 37 |
+
ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)
|
| 38 |
+
|
| 39 |
+
# Parse PFO link arrays
|
| 40 |
+
pfo_calohit = np.array([], dtype=np.int64)
|
| 41 |
+
pfo_track = np.array([], dtype=np.int64)
|
| 42 |
+
if csv_pfo_links and csv_pfo_links.strip():
|
| 43 |
+
lines = csv_pfo_links.strip().split("\n")
|
| 44 |
+
if len(lines) >= 1 and lines[0].strip():
|
| 45 |
+
pfo_calohit = np.array(
|
| 46 |
+
[int(v) for v in lines[0].strip().split(",")], dtype=np.int64
|
| 47 |
+
)
|
| 48 |
+
if len(lines) >= 2 and lines[1].strip():
|
| 49 |
+
pfo_track = np.array(
|
| 50 |
+
[int(v) for v in lines[1].strip().split(",")], dtype=np.int64
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
return {
|
| 54 |
+
"X_hit": hits_arr,
|
| 55 |
+
"X_track": tracks_arr,
|
| 56 |
+
"X_gen": particles_arr,
|
| 57 |
+
"X_pandora": pandora_arr,
|
| 58 |
+
"ygen_hit": ygen_hit,
|
| 59 |
+
"ygen_track": ygen_track,
|
| 60 |
+
"pfo_calohit": pfo_calohit,
|
| 61 |
+
"pfo_track": pfo_track,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
# Tests
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
def test_parse_pfo_links_basic():
    """PFO links should be correctly parsed from csv_pfo_links."""
    hits_csv = (
        "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
        "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
    )
    event = _parse_csv_event(hits_csv, "", "", "", "3,5\n7")

    assert "pfo_calohit" in event
    assert "pfo_track" in event
    np.testing.assert_array_equal(event["pfo_calohit"], [3, 5])
    np.testing.assert_array_equal(event["pfo_track"], [7])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_parse_pfo_links_empty():
    """Empty csv_pfo_links should produce empty arrays."""
    event = _parse_csv_event("0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1", "", "", "", "")

    assert event["pfo_calohit"].size == 0
    assert event["pfo_track"].size == 0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_parse_pfo_links_calohit_only():
    """Only calohit line provided (no track line)."""
    event = _parse_csv_event(
        "0,0,0,0,0,1.0,1.0,1.0,1.0,0,1", "", "", "", "1,2,-1,3"
    )

    np.testing.assert_array_equal(event["pfo_calohit"], [1, 2, -1, 3])
    assert event["pfo_track"].size == 0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_parse_pfo_links_with_negatives():
    """PFO links should correctly handle -1 values (unassigned hits)."""
    event = _parse_csv_event("", "", "", "", "3,-1,5,-1\n-1,2")

    np.testing.assert_array_equal(event["pfo_calohit"], [3, -1, 5, -1])
    np.testing.assert_array_equal(event["pfo_track"], [-1, 2])
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def test_pandora_cluster_partial_match():
    """When CSV is modified (fewer hits than PFO links), use min of lengths."""
    # Simulate the assignment logic from inference.py
    n_calo, n_tracks = 3, 1  # event now holds only 3 hits and 1 track
    total = n_calo + n_tracks

    pfo_calohit = np.array([0, 1, 2, 3, 4], dtype=np.int64)  # originally 5 hits
    pfo_track = np.array([5, 6], dtype=np.int64)  # originally 2 tracks

    ids = np.full(total, -1, dtype=np.int64)
    if pfo_calohit.size:
        take = min(len(pfo_calohit), n_calo)
        ids[:take] = pfo_calohit[:take]
    if n_tracks > 0 and pfo_track.size:
        take = min(len(pfo_track), n_tracks)
        ids[n_calo:n_calo + take] = pfo_track[:take]

    # First 3 calo hits should get their PFO IDs, 4th hit (track) gets first track PFO
    np.testing.assert_array_equal(ids, [0, 1, 2, 5])
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_pandora_cluster_no_links():
    """When no PFO links are available, all pandora_cluster_ids should be -1."""
    n_hits, n_calo, n_tracks = 5, 3, 2

    pfo_calohit = np.array([], dtype=np.int64)
    pfo_track = np.array([], dtype=np.int64)

    ids = np.full(n_hits, -1, dtype=np.int64)
    if pfo_calohit.size:
        take = min(len(pfo_calohit), n_calo)
        ids[:take] = pfo_calohit[:take]
    if n_tracks > 0 and pfo_track.size:
        take = min(len(pfo_track), n_tracks)
        ids[n_calo:n_calo + take] = pfo_track[:take]

    np.testing.assert_array_equal(ids, [-1] * 5)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_pandora_cluster_more_hits_than_links():
    """When more hits exist than PFO links, extra hits get -1."""
    n_calo, n_tracks = 5, 2
    total = n_calo + n_tracks

    pfo_calohit = np.array([1, 2], dtype=np.int64)  # only 2 links for 5 hits
    pfo_track = np.array([3], dtype=np.int64)  # only 1 link for 2 tracks

    ids = np.full(total, -1, dtype=np.int64)
    if pfo_calohit.size:
        take = min(len(pfo_calohit), n_calo)
        ids[:take] = pfo_calohit[:take]
    if n_tracks > 0 and pfo_track.size:
        take = min(len(pfo_track), n_tracks)
        ids[n_calo:n_calo + take] = pfo_track[:take]

    np.testing.assert_array_equal(ids, [1, 2, -1, -1, -1, 3, -1])
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def test_app_source_has_csv_pfo_links_field():
    """app.py should have the csv_pfo_links text field wired up."""
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as fh:
        app_source = fh.read()

    assert "csv_pfo_links" in app_source, "app.py should reference csv_pfo_links"
    assert "Hit → Pandora Cluster links" in app_source, (
        "app.py should have the PFO links text field label"
    )
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def test_run_inference_ui_accepts_pfo_links():
    """run_inference_ui should accept csv_pfo_links as a parameter."""
    import ast
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as fh:
        tree = ast.parse(fh.read())

    for node in ast.walk(tree):
        if not (isinstance(node, ast.FunctionDef) and node.name == "run_inference_ui"):
            continue
        params = [arg.arg for arg in node.args.args]
        assert "csv_pfo_links" in params, (
            "run_inference_ui should accept csv_pfo_links parameter"
        )
        return
    raise AssertionError("Could not find run_inference_ui function in app.py")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_load_event_returns_pfo_links():
    """_load_event_into_csv error path should return 6 values (including PFO links)."""
    import ast
    app_path = os.path.join(os.path.dirname(__file__), "..", "app.py")
    with open(app_path) as fh:
        tree = ast.parse(fh.read())

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == "_load_event_into_csv":
            # Check return statements in the function body; the first
            # tuple-return found decides the verdict.
            for child in ast.walk(node):
                if isinstance(child, ast.Return) and isinstance(child.value, ast.Tuple):
                    count = len(child.value.elts)
                    assert count == 6, (
                        f"_load_event_into_csv should return 6 values, got {count}"
                    )
                    return
    raise AssertionError("Could not find _load_event_into_csv function in app.py")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
    # Run every PFO-links test case in order when executed as a script.
    for _case in (
        test_parse_pfo_links_basic,
        test_parse_pfo_links_empty,
        test_parse_pfo_links_calohit_only,
        test_parse_pfo_links_with_negatives,
        test_pandora_cluster_partial_match,
        test_pandora_cluster_no_links,
        test_pandora_cluster_more_hits_than_links,
        test_app_source_has_csv_pfo_links_field,
        test_run_inference_ui_accepts_pfo_links,
        test_load_event_returns_pfo_links,
    ):
        _case()
    print("All PFO links tests passed.")
|