hawthorneluke commited on
Commit
8cfa91c
·
verified ·
1 Parent(s): 4b20df3

Production LightGBM model — 49 features, 5-fold GroupKFold CV

Browse files
Files changed (5) hide show
  1. README.md +138 -18
  2. features.py +65 -0
  3. model_days.pkl +3 -0
  4. model_info.json +164 -0
  5. model_mag.pkl +3 -0
README.md CHANGED
@@ -1,24 +1,144 @@
1
  ---
2
- tags: [astronomy, supernova, regression, ztf, timm, convnext]
3
- datasets: [MultimodalUniverse/btsbot]
 
 
 
 
 
 
 
 
4
  license: mit
 
5
  ---
6
- # Supernova Peak Predictor
7
- Predicts **when** and **how bright** a supernova will become from early ZTF observations.
8
 
9
- ## Novel Contribution
10
- First model to predict `days_to_peak` and `peakmag` as regression targets from rise-phase alerts.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  ## Architecture
13
- ConvNeXt-pico (Galaxy Zoo pretrained) + 23-feature metadata MLP -> fusion -> dual regression heads
14
- 8,811,816 params (8,425,576 trainable)
15
-
16
- ## Test Results
17
- | Metric | days_to_peak | peakmag |
18
- |--------|-------------|---------|
19
- | MAE | 124.03 days | 0.324 mag |
20
- | Median AE | 29.15 days | 0.223 mag |
21
-
22
- ## Data
23
- [MultimodalUniverse/btsbot](https://huggingface.co/datasets/MultimodalUniverse/btsbot) — rise-phase supernovae only
24
- 3x63x63 image triplets + 23 metadata features. Leaky features (age, days_since_peak) excluded.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ tags:
3
+ - astronomy
4
+ - supernova
5
+ - regression
6
+ - ztf
7
+ - lightgbm
8
+ - tabular
9
+ - time-domain
10
+ datasets:
11
+ - MultimodalUniverse/btsbot
12
  license: mit
13
+ pipeline_tag: tabular-regression
14
  ---
 
 
15
 
16
+ # 🌟 Supernova Peak Predictor
17
+
18
+ **Predicts when and how bright a supernova will become from its earliest ZTF alert observations.**
19
+
20
+ Given a single ZTF alert packet from a rising supernova, this model predicts:
21
+ - **`days_to_peak`** — how many days until the supernova reaches maximum brightness
22
+ - **`peakmag`** — the peak apparent magnitude it will achieve
23
+
24
+ This enables astronomers to answer: *"Should I point the telescope at this target tonight, or can it wait?"*
25
+
26
+ ## Why this matters
27
+
28
+ The bottleneck in transient astronomy isn't detection — ZTF finds thousands of candidates per night. The bottleneck is **follow-up telescope time**. Spectroscopic observations are expensive and limited. Every night, astronomers must decide which of dozens of candidates to prioritize.
29
+
30
+ Currently, that decision is reactive: *"Is this a supernova? Yes → follow up."* ([BTSbot](https://huggingface.co/nabeelr/BTSbot-convnext-pico-galaxyzoo-metadata) solves this with 98.5% accuracy.)
31
+
32
+ Our model makes it **proactive**: *"This SN will peak at mag 17.8 in 5 days → schedule it now"* vs *"This one won't peak for 3 months → deprioritize."*
33
+
34
+ With the Vera C. Rubin Observatory (LSST) coming online, alert rates will jump from ~100K/night to ~10M/night. Automated triage like this will be essential.
35
+
36
+ ## Performance
37
+
38
+ Evaluated with **5-fold grouped cross-validation** (grouped by supernova object ID to prevent data leakage — no alerts from the same SN appear in both train and validation).
39
+
40
+ ### Overall (27,202 alerts from 3,806 supernovae)
41
+
42
+ | Target | MAE | Median AE | P90 |
43
+ |--------|-----|-----------|-----|
44
+ | **days_to_peak** | 118.8 days | 29.8 days | 327.2 days |
45
+ | **peakmag** | 0.257 mag | 0.178 mag | 0.536 mag |
46
+
47
+ ### By number of prior detections
48
+
49
+ | Detection stage | n | MAE days | Median days | MAE mag | Median mag |
50
+ |----------------|---|----------|-------------|---------|------------|
51
+ | **1-3 (first catches)** | 4,817 | 70.8 | **16.5** | 0.391 | 0.304 |
52
+ | 4-10 (early rise) | 9,683 | 114.4 | 26.7 | 0.273 | 0.197 |
53
+ | 11-50 (sampled rise) | 11,029 | 132.4 | 36.3 | 0.199 | 0.147 |
54
+ | 50+ (monitored) | 1,673 | 192.0 | 77.6 | 0.164 | 0.117 |
55
+
56
+ **Key finding:** The model is most useful on the hardest, most valuable cases — the first 1-3 detections — where median timing error is just **16.5 days** and median magnitude error is **0.304 mag**.
57
+
58
+ ### By true time-to-peak
59
+
60
+ | Horizon | n | MAE days | Median days | MAE mag | Median mag |
61
+ |---------|---|----------|-------------|---------|------------|
62
+ | **Imminent (<7d)** | 12,694 | 66.2 | 24.9 | 0.330 | 0.229 |
63
+ | **Soon (7-30d)** | 10,431 | 62.9 | **20.7** | **0.183** | **0.150** |
64
+ | Weeks (30-100d) | 1,086 | 75.5 | 30.2 | 0.211 | 0.148 |
65
+ | Distant (100d+) | 2,991 | 552.5 | 433.0 | 0.229 | 0.164 |
66
+
67
+ The **"soon" horizon (7-30 days)** is the sweet spot — exactly the window where scheduling decisions matter most, and where the model achieves **0.150 mag** median error.
68
 
69
  ## Architecture
70
+
71
+ **LightGBM** gradient boosted trees on **49 engineered features** extracted from ZTF alert metadata.
72
+
73
+ No images are used. We tested a ConvNeXt-pico CNN on the 63×63 difference image triplets and found that **metadata alone outperforms the full multimodal model** (MAE mag 0.257 vs 0.271). The images add noise at this resolution. This is itself a useful finding — it means the predictor can run at alert-stream speed (microseconds per prediction, no GPU needed).
74
+
75
+ ### Top features (by LightGBM importance)
76
+
77
+ **For days_to_peak:** `ncovhist` (coverage history), `distpsnr1` (distance to nearest PS1 source), `distpsnr2`, `neargaia`, `maxmag_so_far`
78
+
79
+ **For peakmag:** `maggaia` (Gaia magnitude), `peakmag_so_far` (brightest seen), `sgscore1` (star/galaxy score), `maxmag_so_far`, `ndethist`
80
+
81
+ The host galaxy properties (PS1 colors, star/galaxy scores, distances) dominate the timing prediction — the model is learning that **where a supernova lives** (host type, distance, environment) constrains **how it evolves**.
82
+
83
+ ## What we tried that didn't work
84
+
85
+ | Approach | Result |
86
+ |----------|--------|
87
+ | ConvNeXt-pico CNN on 63×63 ZTF image triplets + metadata | 6-20% **worse** than metadata-only across all bins |
88
+ | Simple MLP (114K params) on 23 raw features | Competitive but 2-5% worse than tree models with engineered features |
89
+ | Including `age` / `days_since_peak` as input features | Creates direct data leakage (`days_to_peak = age - days_since_peak`) |
90
+
91
+ ## Training data
92
+
93
+ - **Source:** [MultimodalUniverse/btsbot](https://huggingface.co/datasets/MultimodalUniverse/btsbot) — ZTF Bright Transient Survey alerts
94
+ - **Filter:** Rise-phase supernovae only (`is_rise=True`, `is_SN=True`)
95
+ - **Size:** 27,202 alerts from 3,806 unique supernovae
96
+ - **Splits:** 5-fold GroupKFold by object ID (no alert-level leakage)
97
+
98
+ ## Usage
99
+
100
+ ```python
101
+ import pickle, json
102
+ import numpy as np
103
+ from huggingface_hub import hf_hub_download
104
+
105
+ # Download model files
106
+ model_days = pickle.load(open(hf_hub_download("hawthorneluke/supernova-peak-predictor", "model_days.pkl"), "rb"))
107
+ model_mag = pickle.load(open(hf_hub_download("hawthorneluke/supernova-peak-predictor", "model_mag.pkl"), "rb"))
108
+ info = json.load(open(hf_hub_download("hawthorneluke/supernova-peak-predictor", "model_info.json")))
109
+
110
+ # Your ZTF alert metadata (example)
111
+ from features import engineer_features # download features.py from this repo
112
+ alert = {"magpsf": 19.2, "sigmapsf": 0.15, "ndethist": 3, ...} # ZTF alert fields
113
+ feats = engineer_features(alert)
114
+
115
+ # Predict
116
+ X = np.array([[feats[c] for c in info['feature_cols']]], dtype=np.float32)
117
+ days_pred = model_days.predict(X)[0]
118
+ mag_pred = model_mag.predict(X)[0]
119
+ print(f"Predicted: peak in {days_pred:.1f} days at magnitude {mag_pred:.2f}")
120
+ ```
121
+
122
+ ## Interactive demo
123
+
124
+ Try it live: [hawthorneluke/supernova-peak-predictor-demo](https://huggingface.co/spaces/hawthorneluke/supernova-peak-predictor-demo)
125
+
126
+ ## Limitations
127
+
128
+ - **Long-horizon predictions are poor.** For SNe >100 days from peak, the MAE is 552 days. The model essentially can't predict these — they're rare, slow-evolving transients with ambiguous early signatures.
129
+ - **No uncertainty quantification.** The model gives point estimates. A production system would need prediction intervals.
130
+ - **ZTF-specific.** Features are tied to ZTF alert schema. Adaptation to LSST/Rubin alerts would require feature remapping.
131
+ - **No spectroscopic type prediction.** We predict timing and brightness but not SN type (Ia vs II vs Ibc). This would be a natural extension.
132
+
133
+ ## Citation
134
+
135
+ If you use this model, please cite the underlying data:
136
+
137
+ ```bibtex
138
+ @article{rehemtulla2024btsbot,
139
+ title={BTSbot: A Multi-modal Deep Learning Model for Automated Bright Transient Identification},
140
+ author={Rehemtulla, Nabeel and others},
141
+ journal={arXiv preprint arXiv:2401.15167},
142
+ year={2024}
143
+ }
144
+ ```
features.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Feature engineering for Supernova Peak Predictor."""
2
+
3
+ def engineer_features(row):
4
+ """Extract features from a ZTF alert metadata dict.
5
+
6
+ Args:
7
+ row: dict with ZTF alert fields (magpsf, sigmapsf, etc.)
8
+ Returns:
9
+ dict of engineered features
10
+ """
11
+ feats = {}
12
+ feats['magpsf'] = float(row.get('magpsf', 0) or 0)
13
+ feats['sigmapsf'] = float(row.get('sigmapsf', 0) or 0)
14
+ feats['magap'] = float(row.get('magap', 0) or 0)
15
+ feats['sigmagap'] = float(row.get('sigmagap', 0) or 0)
16
+ feats['diffmaglim'] = float(row.get('diffmaglim', 0) or 0)
17
+ feats['peakmag_so_far'] = float(row.get('peakmag_so_far', 0) or 0)
18
+ feats['maxmag_so_far'] = float(row.get('maxmag_so_far', 0) or 0)
19
+ feats['mag_range'] = feats['maxmag_so_far'] - feats['peakmag_so_far']
20
+ feats['mag_vs_peak'] = feats['magpsf'] - feats['peakmag_so_far']
21
+ feats['mag_vs_lim'] = feats['diffmaglim'] - feats['magpsf']
22
+ feats['mag_psf_ap_diff'] = feats['magpsf'] - feats['magap']
23
+ feats['ndethist'] = float(row.get('ndethist', 0) or 0)
24
+ feats['ncovhist'] = float(row.get('ncovhist', 0) or 0)
25
+ feats['nnotdet'] = float(row.get('nnotdet', 0) or 0)
26
+ feats['nmtchps'] = float(row.get('nmtchps', 0) or 0)
27
+ feats['det_fraction'] = feats['ndethist'] / (feats['ncovhist'] + 1)
28
+ feats['N'] = float(row.get('N', 0) or 0)
29
+ feats['nneg'] = float(row.get('nneg', 0) or 0)
30
+ feats['nbad'] = float(row.get('nbad', 0) or 0)
31
+ feats['fwhm'] = float(row.get('fwhm', 0) or 0)
32
+ feats['chipsf'] = float(row.get('chipsf', 0) or 0)
33
+ feats['chinr'] = float(row.get('chinr', 0) or 0)
34
+ feats['sharpnr'] = float(row.get('sharpnr', 0) or 0)
35
+ feats['scorr'] = float(row.get('scorr', 0) or 0)
36
+ feats['sky'] = float(row.get('sky', 0) or 0)
37
+ feats['classtar'] = float(row.get('classtar', 0) or 0)
38
+ feats['new_drb'] = float(row.get('new_drb', 0) or 0)
39
+ feats['drb'] = float(row.get('drb', 0) or 0)
40
+ feats['exptime'] = float(row.get('exptime', 30) or 30)
41
+ feats['sgscore1'] = float(row.get('sgscore1', 0) or 0)
42
+ feats['distpsnr1'] = float(row.get('distpsnr1', 0) or 0)
43
+ feats['sgscore2'] = float(row.get('sgscore2', 0) or 0)
44
+ feats['distpsnr2'] = float(row.get('distpsnr2', 0) or 0)
45
+ feats['distnr'] = float(row.get('distnr', 0) or 0)
46
+ feats['magnr'] = float(row.get('magnr', 0) or 0)
47
+ feats['mag_vs_host'] = feats['magpsf'] - feats['magnr']
48
+ feats['neargaia'] = float(row.get('neargaia', 0) or 0)
49
+ v = row.get('neargaia', 0) or 0
50
+ feats['neargaia'] = float(v) if float(v) > -998 else 0
51
+ feats['maggaia'] = float(row.get('maggaia', 0) or 0)
52
+ v = row.get('maggaia', 0) or 0
53
+ feats['maggaia'] = float(v) if float(v) > -998 else 0
54
+ for col in ['sgmag1', 'srmag1', 'simag1', 'szmag1']:
55
+ val = row.get(col, -999) or -999
56
+ feats[col] = float(val) if float(val) > -998 else 0
57
+ sg, sr, si, sz = feats['sgmag1'], feats['srmag1'], feats['simag1'], feats['szmag1']
58
+ feats['host_g_r'] = (sg - sr) if sg > 0 and sr > 0 else 0
59
+ feats['host_r_i'] = (sr - si) if sr > 0 and si > 0 else 0
60
+ feats['host_i_z'] = (si - sz) if si > 0 and sz > 0 else 0
61
+ feats['fid'] = float(row.get('fid', 1))
62
+ feats['is_g_band'] = 1.0 if feats['fid'] == 1 else 0.0
63
+ feats['ndethist_x_magrange'] = feats['ndethist'] * feats['mag_range']
64
+ feats['snr_proxy'] = feats['mag_vs_lim'] / (feats['sigmapsf'] + 0.01)
65
+ return feats
model_days.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d08ad55b1af0b994feab7a2e35437916c5521207be70ef15a0498078d2c506c
3
+ size 1382042
model_info.json ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "LightGBM",
3
+ "feature_cols": [
4
+ "magpsf",
5
+ "sigmapsf",
6
+ "magap",
7
+ "sigmagap",
8
+ "diffmaglim",
9
+ "peakmag_so_far",
10
+ "maxmag_so_far",
11
+ "mag_range",
12
+ "mag_vs_peak",
13
+ "mag_vs_lim",
14
+ "mag_psf_ap_diff",
15
+ "ndethist",
16
+ "ncovhist",
17
+ "nnotdet",
18
+ "nmtchps",
19
+ "det_fraction",
20
+ "N",
21
+ "nneg",
22
+ "nbad",
23
+ "fwhm",
24
+ "chipsf",
25
+ "chinr",
26
+ "sharpnr",
27
+ "scorr",
28
+ "sky",
29
+ "classtar",
30
+ "new_drb",
31
+ "drb",
32
+ "exptime",
33
+ "sgscore1",
34
+ "distpsnr1",
35
+ "sgscore2",
36
+ "distpsnr2",
37
+ "distnr",
38
+ "magnr",
39
+ "mag_vs_host",
40
+ "neargaia",
41
+ "maggaia",
42
+ "sgmag1",
43
+ "srmag1",
44
+ "simag1",
45
+ "szmag1",
46
+ "host_g_r",
47
+ "host_r_i",
48
+ "host_i_z",
49
+ "fid",
50
+ "is_g_band",
51
+ "ndethist_x_magrange",
52
+ "snr_proxy"
53
+ ],
54
+ "n_features": 49,
55
+ "cv_metrics": {
56
+ "mae_d": 118.77106475830078,
57
+ "mae_m": 0.2573636472225189,
58
+ "med_d": 29.818622589111328,
59
+ "med_m": 0.17830181121826172,
60
+ "p90_d": 327.1890563964844,
61
+ "p90_m": 0.5362690091133118
62
+ },
63
+ "fold_results": [
64
+ {
65
+ "fold": 1,
66
+ "xgb_mae_d": 119.88993072509766,
67
+ "xgb_mae_m": 0.2529917061328888,
68
+ "xgb_med_d": 27.38703727722168,
69
+ "xgb_med_m": 0.17678070068359375,
70
+ "lgb_mae_d": 120.7685546875,
71
+ "lgb_mae_m": 0.2521520256996155,
72
+ "lgb_med_d": 29.628095626831055,
73
+ "lgb_med_m": 0.17668533325195312
74
+ },
75
+ {
76
+ "fold": 2,
77
+ "xgb_mae_d": 114.90345001220703,
78
+ "xgb_mae_m": 0.2524799108505249,
79
+ "xgb_med_d": 31.275833129882812,
80
+ "xgb_med_m": 0.17509841918945312,
81
+ "lgb_mae_d": 111.13099670410156,
82
+ "lgb_mae_m": 0.25365331768989563,
83
+ "lgb_med_d": 27.839418411254883,
84
+ "lgb_med_m": 0.17804908752441406
85
+ },
86
+ {
87
+ "fold": 3,
88
+ "xgb_mae_d": 140.1485137939453,
89
+ "xgb_mae_m": 0.27031439542770386,
90
+ "xgb_med_d": 38.103515625,
91
+ "xgb_med_m": 0.17546653747558594,
92
+ "lgb_mae_d": 135.17355346679688,
93
+ "lgb_mae_m": 0.27064448595046997,
94
+ "lgb_med_d": 31.190265655517578,
95
+ "lgb_med_m": 0.17244434356689453
96
+ },
97
+ {
98
+ "fold": 4,
99
+ "xgb_mae_d": 115.05944061279297,
100
+ "xgb_mae_m": 0.26252108812332153,
101
+ "xgb_med_d": 31.309772491455078,
102
+ "xgb_med_m": 0.18577289581298828,
103
+ "lgb_mae_d": 114.0598373413086,
104
+ "lgb_mae_m": 0.25898680090904236,
105
+ "lgb_med_d": 29.387041091918945,
106
+ "lgb_med_m": 0.18349266052246094
107
+ },
108
+ {
109
+ "fold": 5,
110
+ "xgb_mae_d": 117.32633209228516,
111
+ "xgb_mae_m": 0.24888695776462555,
112
+ "xgb_med_d": 29.38198471069336,
113
+ "xgb_med_m": 0.18015766143798828,
114
+ "lgb_mae_d": 112.7234115600586,
115
+ "lgb_mae_m": 0.2513831853866577,
116
+ "lgb_med_d": 31.076351165771484,
117
+ "lgb_med_m": 0.18096446990966797
118
+ }
119
+ ],
120
+ "n_train": 27202,
121
+ "n_objects": 3806,
122
+ "target_stats": {
123
+ "days_mean": 99.75027465820312,
124
+ "days_std": 300.12255859375,
125
+ "days_median": 7.939062595367432,
126
+ "mag_mean": 18.47530746459961,
127
+ "mag_std": 0.8604589104652405,
128
+ "mag_median": 18.654937744140625
129
+ },
130
+ "feature_importance_days": {
131
+ "ncovhist": 1014.0,
132
+ "distpsnr1": 848.0,
133
+ "distpsnr2": 777.0,
134
+ "neargaia": 754.0,
135
+ "maxmag_so_far": 645.0,
136
+ "nnotdet": 637.0,
137
+ "host_r_i": 629.0,
138
+ "host_g_r": 595.0,
139
+ "maggaia": 543.0,
140
+ "sgscore1": 518.0,
141
+ "simag1": 510.0,
142
+ "host_i_z": 503.0,
143
+ "sgmag1": 477.0,
144
+ "distnr": 461.0,
145
+ "magnr": 451.0
146
+ },
147
+ "feature_importance_mag": {
148
+ "maggaia": 713.0,
149
+ "peakmag_so_far": 666.0,
150
+ "sgscore1": 576.0,
151
+ "maxmag_so_far": 570.0,
152
+ "ndethist": 557.0,
153
+ "det_fraction": 542.0,
154
+ "nmtchps": 534.0,
155
+ "host_i_z": 519.0,
156
+ "N": 517.0,
157
+ "sgscore2": 505.0,
158
+ "distpsnr2": 504.0,
159
+ "host_r_i": 503.0,
160
+ "host_g_r": 493.0,
161
+ "ncovhist": 490.0,
162
+ "simag1": 483.0
163
+ }
164
+ }
model_mag.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1c9944bd94483a4b3a036e1369fc77bb0d83a4c1b268701995eda1d2970ec2a
3
+ size 1461336