DTanzillo commited on
Commit
a4b5ecb
·
verified ·
1 Parent(s): c3d97e1

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Panacea FastAPI backend, deployed on HuggingFace Spaces.
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies (gcc for building wheels with no prebuilt binary);
# the apt cache is removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends gcc && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install CPU-only PyTorch first (smaller)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu

# Install remaining dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app/ ./app/
COPY src/ ./src/
COPY results/ ./results/
COPY app_wrapper.py .

# Create models directory (will be populated at runtime)
RUN mkdir -p models

# HF Spaces expects port 7860
EXPOSE 7860

# Run the wrapper that downloads models then starts uvicorn
CMD ["python", "app_wrapper.py"]
README.md CHANGED
@@ -1,10 +1,31 @@
1
  ---
2
- title: Panacea Api
3
- emoji: 🐢
4
- colorFrom: yellow
5
- colorTo: gray
6
  sdk: docker
 
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Panacea Satellite Collision Avoidance API
3
+ colorFrom: indigo
4
+ colorTo: blue
 
5
  sdk: docker
6
+ app_port: 7860
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # Panacea -- Satellite Collision Avoidance API
12
+
13
+ FastAPI backend for the Panacea satellite collision avoidance system.
14
+
15
+ ## Endpoints
16
+
17
+ - `GET /api/health` -- Health check, lists loaded models
18
+ - `POST /api/predict-conjunction` -- Run inference on a CDM sequence
19
+ - `GET /api/model-comparison` -- Pre-computed model comparison results
20
+ - `GET /api/experiment-results` -- Staleness experiment results
21
+ - `POST /api/bulk-screen` -- Screen TLE pairs for potential conjunctions
22
+
23
+ ## Models
24
+
25
+ Three models are loaded at startup from [DTanzillo/panacea-models](https://huggingface.co/DTanzillo/panacea-models):
26
+
27
+ 1. **Baseline** -- Orbital shell density prior (AUC-PR: 0.061)
28
+ 2. **XGBoost** -- Classical ML on engineered CDM features (AUC-PR: 0.988)
29
+ 3. **PI-TFT** -- Physics-Informed Temporal Fusion Transformer (AUC-PR: 0.511)
30
+
31
+ Built for AIPI 540 (Duke University).
app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Generated by Claude Code -- 2026-02-13
app/main.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-13
2
+ """FastAPI backend for Panacea collision avoidance inference."""
3
+
4
+ import json
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ from contextlib import asynccontextmanager
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ from fastapi import FastAPI
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from pydantic import BaseModel
15
+
16
+ import sys
17
+
18
+ ROOT = Path(__file__).parent.parent
19
+ sys.path.insert(0, str(ROOT))
20
+
21
+ from src.model.baseline import OrbitalShellBaseline
22
+ from src.model.classical import XGBoostConjunctionModel
23
+ from src.model.deep import PhysicsInformedTFT
24
+ from src.model.triage import classify_urgency
25
+ from src.data.sequence_builder import TEMPORAL_FEATURES, STATIC_FEATURES, MAX_SEQ_LEN
26
+
27
+ HF_REPO_ID = "DTanzillo/panacea-models"
28
+
29
+ # Global model storage
30
+ models = {}
31
+
32
+
33
def download_models_from_hf(model_dir: Path, results_dir: Path):
    """Download models from HuggingFace Hub if not available locally.

    Best-effort: any failure (missing huggingface_hub package, network
    error, absent repo directory) is logged and swallowed so the API can
    still start with whatever local files exist.

    Args:
        model_dir: Local directory to receive model artifact files.
        results_dir: Local directory to receive result JSON files.
    """
    try:
        import shutil  # hoisted: previously re-imported on every copied file
        from huggingface_hub import snapshot_download

        token = os.environ.get("HF_TOKEN")
        local = Path(snapshot_download(
            HF_REPO_ID,
            token=token,
            allow_patterns=["models/*", "results/*"],
        ))
        # Copy files to expected locations. Guard each subdirectory: a
        # snapshot missing "models" previously raised on iterdir() and
        # silently skipped the "results" copy as well.
        for sub_name, dst_dir in (("models", model_dir), ("results", results_dir)):
            src_dir = local / sub_name
            if not src_dir.exists():
                continue
            for src in src_dir.iterdir():
                dst = dst_dir / src.name
                if not dst.exists():
                    shutil.copy2(src, dst)
                    print(f" Downloaded {src.name} from HF Hub")
    except Exception as e:
        print(f" HF Hub download skipped: {e}")
59
+
60
+
61
def load_models():
    """Load all 3 models at startup. Downloads from HF Hub if missing.

    Populates the module-level ``models`` registry with keys "baseline",
    "xgboost", "pitft" (each only if its artifact file exists), plus the
    internal keys "pitft_checkpoint" and "pitft_device" used at inference
    time. Missing artifacts are skipped silently, so the API can run with
    any subset of the three models.
    """
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Try downloading from HF Hub if local models are missing
    # (baseline.json is used as the sentinel for "models are present").
    if not (model_dir / "baseline.json").exists():
        print(" Local models not found, trying HuggingFace Hub...")
        download_models_from_hf(model_dir, results_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    baseline_path = model_dir / "baseline.json"
    if baseline_path.exists():
        models["baseline"] = OrbitalShellBaseline.load(baseline_path)
        print(" Loaded baseline model")

    xgboost_path = model_dir / "xgboost.pkl"
    if xgboost_path.exists():
        models["xgboost"] = XGBoostConjunctionModel.load(xgboost_path)
        print(" Loaded XGBoost model")

    pitft_path = model_dir / "transformer.pt"
    if pitft_path.exists():
        # NOTE(review): weights_only=False unpickles arbitrary objects — safe
        # only because the checkpoint comes from the project's own HF repo;
        # do not point this at untrusted files.
        checkpoint = torch.load(pitft_path, map_location=device, weights_only=False)
        config = checkpoint["config"]

        model = PhysicsInformedTFT(
            n_temporal_features=config["n_temporal"],
            n_static_features=config["n_static"],
            d_model=config.get("d_model", 128),
            n_heads=config.get("n_heads", 4),
            n_layers=config.get("n_layers", 2),
        ).to(device)
        # strict=False for backward compat: old checkpoints lack pc_head weights
        model.load_state_dict(checkpoint["model_state"], strict=False)
        model.eval()

        models["pitft"] = model
        models["pitft_checkpoint"] = checkpoint
        models["pitft_device"] = device
        temp = checkpoint.get("temperature", 1.0)
        has_pc = checkpoint.get("has_pc_head", False)
        print(f" Loaded PI-TFT (epoch {checkpoint['epoch']}, T={temp:.3f}, pc_head={'yes' if has_pc else 'no'})")
106
+
107
+
108
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: populate the model registry, then tear it down."""
    print("Loading models ...")
    load_models()
    # Internal "pitft_*" bookkeeping keys are not user-facing models.
    public_names = [name for name in models if not name.startswith("pitft_")]
    print(f"Models loaded: {public_names}")
    yield
    # Shutdown: drop references so model memory can be reclaimed.
    models.clear()
116
+
117
+
118
# Application object; `lifespan` loads models on startup and clears them on shutdown.
app = FastAPI(
    title="Panacea — Satellite Collision Avoidance API",
    version="1.0.0",
    lifespan=lifespan,
)

# Open CORS for a separately hosted frontend.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — confirm whether
# credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
131
+
132
+
133
+ # --- Pydantic models ---
134
+
135
class CDMFeatures(BaseModel):
    """A sequence of CDM feature snapshots for one conjunction event."""
    # Optional external identifier for the event; not used during inference.
    event_id: Optional[int] = None
    # CDM snapshots as feature-name -> value dicts; downstream code treats the
    # last entry as the most recent CDM (presumably ordered oldest-to-newest —
    # confirm with callers).
    cdm_sequence: list[dict]
139
+
140
+
141
class BulkScreenRequest(BaseModel):
    """TLE data for pairwise screening."""
    # TLE records as dicts; screening reads OMM-style keys such as
    # OBJECT_NAME, NORAD_CAT_ID, MEAN_MOTION, ECCENTRICITY, RA_OF_ASC_NODE.
    tles: list[dict]
    # Number of highest-risk candidate pairs to return.
    top_k: int = 10
145
+
146
+
147
+ # --- Endpoints ---
148
+
149
+ @app.get("/api/health")
150
+ async def health():
151
+ loaded = []
152
+ if "baseline" in models:
153
+ loaded.append("baseline")
154
+ if "xgboost" in models:
155
+ loaded.append("xgboost")
156
+ if "pitft" in models:
157
+ loaded.append("pitft")
158
+
159
+ device = str(models.get("pitft_device", "cpu"))
160
+ return {
161
+ "status": "healthy",
162
+ "models_loaded": loaded,
163
+ "device": device,
164
+ "n_models": len(loaded),
165
+ }
166
+
167
+
168
+ @app.post("/api/predict-conjunction")
169
+ async def predict_conjunction(features: CDMFeatures):
170
+ """Run inference on a single conjunction event across all loaded models."""
171
+ results = {}
172
+ cdm_seq = features.cdm_sequence
173
+ if not cdm_seq:
174
+ return {"error": "Empty CDM sequence"}
175
+
176
+ last_cdm = cdm_seq[-1]
177
+ altitude = last_cdm.get("t_h_apo", last_cdm.get("c_h_apo", 500.0))
178
+
179
+ # Baseline prediction
180
+ if "baseline" in models:
181
+ risk_probs, miss_preds = models["baseline"].predict(np.array([altitude]))
182
+ triage = classify_urgency(float(risk_probs[0]))
183
+ results["baseline"] = {
184
+ "risk_probability": float(risk_probs[0]),
185
+ "miss_distance_km": float(np.expm1(miss_preds[0])),
186
+ "triage": {
187
+ "tier": triage.tier.value,
188
+ "color": triage.color,
189
+ "recommendation": triage.recommendation,
190
+ },
191
+ }
192
+
193
+ # XGBoost prediction
194
+ if "xgboost" in models:
195
+ xgb_features = _build_xgboost_features(cdm_seq)
196
+ risk_probs, miss_km = models["xgboost"].predict(xgb_features)
197
+ triage = classify_urgency(float(risk_probs[0]))
198
+ results["xgboost"] = {
199
+ "risk_probability": float(risk_probs[0]),
200
+ "miss_distance_km": float(miss_km[0]),
201
+ "triage": {
202
+ "tier": triage.tier.value,
203
+ "color": triage.color,
204
+ "recommendation": triage.recommendation,
205
+ },
206
+ }
207
+
208
+ # PI-TFT prediction
209
+ if "pitft" in models:
210
+ risk_prob, miss_log, pc_log10 = _run_pitft_inference(cdm_seq)
211
+ triage = classify_urgency(risk_prob)
212
+ results["pitft"] = {
213
+ "risk_probability": risk_prob,
214
+ "miss_distance_km": float(np.expm1(miss_log)),
215
+ "collision_probability": float(10 ** pc_log10),
216
+ "collision_probability_log10": pc_log10,
217
+ "triage": {
218
+ "tier": triage.tier.value,
219
+ "color": triage.color,
220
+ "recommendation": triage.recommendation,
221
+ },
222
+ }
223
+
224
+ return results
225
+
226
+
227
+ @app.get("/api/model-comparison")
228
+ async def model_comparison():
229
+ """Return pre-computed model comparison results."""
230
+ results = []
231
+
232
+ comparison_path = ROOT / "results" / "model_comparison.json"
233
+ if comparison_path.exists():
234
+ with open(comparison_path) as f:
235
+ results = json.load(f)
236
+
237
+ deep_path = ROOT / "results" / "deep_model_results.json"
238
+ if deep_path.exists():
239
+ with open(deep_path) as f:
240
+ deep = json.load(f)
241
+ pitft_entry = {
242
+ "model": deep["model"],
243
+ **deep["test"],
244
+ }
245
+ results.append(pitft_entry)
246
+
247
+ return results
248
+
249
+
250
+ @app.get("/api/experiment-results")
251
+ async def experiment_results():
252
+ """Return staleness experiment results."""
253
+ exp_path = ROOT / "results" / "staleness_experiment.json"
254
+ if exp_path.exists():
255
+ with open(exp_path) as f:
256
+ return json.load(f)
257
+ return {"error": "No experiment results found. Run: python scripts/run_experiment.py"}
258
+
259
+
260
+ @app.post("/api/bulk-screen")
261
+ async def bulk_screen(request: BulkScreenRequest):
262
+ """Screen TLE pairs for potential conjunctions using orbital filtering."""
263
+ tles = request.tles
264
+ top_k = request.top_k
265
+
266
+ if len(tles) < 2:
267
+ return {"pairs": [], "n_candidates": 0, "n_total": len(tles)}
268
+
269
+ n = len(tles)
270
+ names = [t.get("OBJECT_NAME", f"Object {i}") for i, t in enumerate(tles)]
271
+ norad_ids = [t.get("NORAD_CAT_ID", 0) for t in tles]
272
+
273
+ # Compute altitude from mean motion: a = (mu / n^2)^(1/3), alt = a - R_earth
274
+ MU = 398600.4418 # km^3/s^2
275
+ R_EARTH = 6371.0 # km
276
+
277
+ mean_motions = np.array([t.get("MEAN_MOTION", 15.0) for t in tles])
278
+ n_rad = mean_motions * 2 * np.pi / 86400.0
279
+ n_rad = np.clip(n_rad, 1e-10, None)
280
+ sma = (MU / (n_rad ** 2)) ** (1.0 / 3.0)
281
+
282
+ eccentricities = np.array([t.get("ECCENTRICITY", 0.0) for t in tles])
283
+ apogee = sma * (1 + eccentricities) - R_EARTH
284
+ perigee = sma * (1 - eccentricities) - R_EARTH
285
+
286
+ raan = np.array([t.get("RA_OF_ASC_NODE", 0.0) for t in tles])
287
+
288
+ # Pairwise filtering via broadcasting
289
+ alt_overlap = ((apogee[:, None] >= perigee[None, :]) &
290
+ (apogee[None, :] >= perigee[:, None]))
291
+
292
+ raan_diff = np.abs(raan[:, None] - raan[None, :])
293
+ raan_diff = np.minimum(raan_diff, 360.0 - raan_diff)
294
+ raan_close = raan_diff < 30.0
295
+
296
+ candidates = alt_overlap & raan_close
297
+ np.fill_diagonal(candidates, False)
298
+ candidates = np.triu(candidates, k=1)
299
+
300
+ pairs_i, pairs_j = np.where(candidates)
301
+
302
+ if len(pairs_i) == 0:
303
+ return {"pairs": [], "n_candidates": 0, "n_total": n}
304
+
305
+ # Score candidates using baseline model
306
+ if "baseline" in models:
307
+ pair_altitudes = (apogee[pairs_i] + apogee[pairs_j]) / 2.0
308
+ risk_scores, miss_estimates = models["baseline"].predict(pair_altitudes)
309
+ else:
310
+ risk_scores = np.ones(len(pairs_i)) * 0.5
311
+ miss_estimates = np.zeros(len(pairs_i))
312
+
313
+ top_indices = np.argsort(-risk_scores)[:top_k]
314
+
315
+ result_pairs = []
316
+ for idx in top_indices:
317
+ i, j = int(pairs_i[idx]), int(pairs_j[idx])
318
+ result_pairs.append({
319
+ "name_1": names[i],
320
+ "name_2": names[j],
321
+ "norad_1": norad_ids[i],
322
+ "norad_2": norad_ids[j],
323
+ "risk_score": float(risk_scores[idx]),
324
+ "altitude_km": float((apogee[i] + apogee[j]) / 2),
325
+ "miss_estimate_km": (float(np.expm1(miss_estimates[idx]))
326
+ if miss_estimates[idx] > 0 else 0.0),
327
+ })
328
+
329
+ return {
330
+ "pairs": result_pairs,
331
+ "n_candidates": int(len(pairs_i)),
332
+ "n_total": n,
333
+ }
334
+
335
+
336
+ # --- Helper functions ---
337
+
338
def _build_xgboost_features(cdm_sequence: list[dict]) -> np.ndarray:
    """Build XGBoost feature vector from a CDM sequence (dict format).

    Replicates events_to_flat_features() logic for a single event.

    Args:
        cdm_sequence: CDM snapshots for one event; the last entry is
            treated as the most recent CDM.

    Returns:
        Array of shape (1, n_features), zero-padded or truncated to the
        column count the loaded XGBoost scaler expects.
    """
    last = cdm_sequence[-1]

    # Sorted numeric keys give a deterministic column order; identifiers
    # and targets are excluded from the feature set.
    exclude = {"event_id", "time_to_tca", "risk", "mission_id"}
    feature_keys = sorted([
        k for k in last.keys()
        if isinstance(last.get(k), (int, float)) and k not in exclude
    ])

    base = np.array([float(last.get(k, 0.0)) for k in feature_keys], dtype=np.float32)

    # Per-sequence series used for the engineered temporal summary features.
    miss_values = np.array([float(s.get("miss_distance", 0.0)) for s in cdm_sequence])
    risk_values = np.array([float(s.get("risk", -10.0)) for s in cdm_sequence])
    tca_values = np.array([float(s.get("time_to_tca", 0.0)) for s in cdm_sequence])

    n_cdms = len(cdm_sequence)
    miss_mean = float(np.mean(miss_values))
    miss_std = float(np.std(miss_values)) if n_cdms > 1 else 0.0

    # Linear-fit slopes vs time-to-TCA; require >1 CDM and non-constant
    # time values for a well-posed fit.
    miss_trend = 0.0
    if n_cdms > 1 and np.std(tca_values) > 0:
        miss_trend = float(np.polyfit(tca_values, miss_values, 1)[0])

    risk_trend = 0.0
    if n_cdms > 1 and np.std(tca_values) > 0:
        risk_trend = float(np.polyfit(tca_values, risk_values, 1)[0])

    temporal_feats = np.array([
        n_cdms,
        miss_mean,
        miss_std,
        miss_trend,
        risk_trend,
        # Net change in miss distance from first to last CDM.
        float(miss_values[0] - miss_values[-1]) if n_cdms > 1 else 0.0,
        float(last.get("time_to_tca", 0.0)),
        float(last.get("relative_speed", 0.0)),
    ], dtype=np.float32)

    combined = np.concatenate([base, temporal_feats])
    # Sanitize: the model cannot accept NaN/inf values.
    combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
    X = combined.reshape(1, -1)

    # Pad features if model was trained on augmented data with more columns
    if "xgboost" in models:
        expected = models["xgboost"].scaler.n_features_in_
        if X.shape[1] < expected:
            padding = np.zeros((X.shape[0], expected - X.shape[1]), dtype=X.dtype)
            X = np.hstack([X, padding])
        elif X.shape[1] > expected:
            # NOTE(review): truncation drops trailing sorted-key columns —
            # assumes the training columns are a prefix of these; confirm.
            X = X[:, :expected]

    return X
394
+
395
+
396
def _run_pitft_inference(cdm_sequence: list[dict]) -> tuple[float, float, float]:
    """Run PI-TFT inference on a single CDM sequence.

    Returns:
        (risk_probability, miss_log, pc_log10) — temperature-calibrated
        risk probability in [0, 1], predicted miss distance in log1p(km)
        space, and log10 of the collision probability.
    """
    checkpoint = models["pitft_checkpoint"]
    device = models["pitft_device"]
    model = models["pitft"]
    norm = checkpoint["normalization"]
    temperature = checkpoint.get("temperature", 1.0)
    # Fall back to the module-level feature lists for older checkpoints.
    temporal_cols = checkpoint.get("temporal_cols", TEMPORAL_FEATURES)
    static_cols = checkpoint.get("static_cols", STATIC_FEATURES)

    # Extract temporal features: (S, F_t)
    temporal = np.array([
        [float(cdm.get(col, 0.0)) for col in temporal_cols]
        for cdm in cdm_sequence
    ], dtype=np.float32)
    temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

    # Compute deltas between consecutive CDMs; first row is zero so the
    # delta channel stays aligned with the temporal channel.
    if len(temporal) > 1:
        deltas = np.diff(temporal, axis=0)
        deltas = np.concatenate(
            [np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0
        )
    else:
        deltas = np.zeros_like(temporal)

    # Normalize with the training-time statistics stored in the checkpoint.
    t_mean = np.array(norm["temporal_mean"], dtype=np.float32)
    t_std = np.array(norm["temporal_std"], dtype=np.float32)
    d_mean = np.array(norm["delta_mean"], dtype=np.float32)
    d_std = np.array(norm["delta_std"], dtype=np.float32)
    s_mean = np.array(norm["static_mean"], dtype=np.float32)
    s_std = np.array(norm["static_std"], dtype=np.float32)

    temporal = (temporal - t_mean) / t_std
    deltas = (deltas - d_mean) / d_std
    # Concatenate value and delta channels along the feature axis.
    temporal = np.concatenate([temporal, deltas], axis=1)

    # Static features from last CDM
    last_cdm = cdm_sequence[-1]
    static = np.array(
        [float(last_cdm.get(col, 0.0)) for col in static_cols], dtype=np.float32
    )
    static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
    static = (static - s_mean) / s_std

    # Time-to-TCA, normalized with checkpoint stats, shape (S, 1).
    tca_mean = norm["tca_mean"]
    tca_std = norm["tca_std"]
    tca = np.array(
        [float(cdm.get("time_to_tca", 0.0)) for cdm in cdm_sequence], dtype=np.float32
    ).reshape(-1, 1)
    tca = (tca - tca_mean) / tca_std

    # Pad/truncate to MAX_SEQ_LEN, keeping the most recent CDMs and
    # left-padding with zeros (mask marks real positions).
    seq_len = len(temporal)
    if seq_len > MAX_SEQ_LEN:
        temporal = temporal[-MAX_SEQ_LEN:]
        tca = tca[-MAX_SEQ_LEN:]
        seq_len = MAX_SEQ_LEN

    pad_len = MAX_SEQ_LEN - seq_len
    if pad_len > 0:
        temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
        tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

    mask = np.zeros(MAX_SEQ_LEN, dtype=bool)
    mask[pad_len:] = True

    # Convert to tensors with a leading batch dimension of 1.
    temporal_t = torch.tensor(temporal, dtype=torch.float32).unsqueeze(0).to(device)
    static_t = torch.tensor(static, dtype=torch.float32).unsqueeze(0).to(device)
    tca_t = torch.tensor(tca, dtype=torch.float32).unsqueeze(0).to(device)
    mask_t = torch.tensor(mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        risk_logit, miss_log, pc_log10, _ = model(temporal_t, static_t, tca_t, mask_t)

    # Temperature scaling calibrates the raw logit before the sigmoid.
    risk_prob = float(torch.sigmoid(risk_logit / temperature).cpu().item())
    miss_log_val = float(miss_log.cpu().item())
    pc_log10_val = float(pc_log10.cpu().item())

    return risk_prob, miss_log_val, pc_log10_val
app_wrapper.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Startup wrapper for HuggingFace Spaces deployment.
2
+
3
+ Downloads models from DTanzillo/panacea-models on first run,
4
+ then starts the FastAPI application on port 7860.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import shutil
10
+ from pathlib import Path
11
+
12
+ # Ensure the app root is on the Python path
13
+ ROOT = Path(__file__).parent
14
+ sys.path.insert(0, str(ROOT))
15
+
16
+
17
def _copy_missing(src_dir: Path, dst_dir: Path, prefix: str = "") -> None:
    """Copy every file from src_dir into dst_dir that is not already there.

    No-op when src_dir does not exist. `prefix` is prepended to the file
    name in the progress message.
    """
    if not src_dir.exists():
        return
    for src_file in src_dir.iterdir():
        dst_file = dst_dir / src_file.name
        if not dst_file.exists():
            shutil.copy2(src_file, dst_file)
            print(f" Copied {prefix}{src_file.name}")


def download_models():
    """Download models from HuggingFace Hub if not present locally.

    Fetches model and result files from DTanzillo/panacea-models into the
    local models/ and results/ directories, copying only files that are
    missing. Failures are logged but not raised, so the API can still
    start (possibly without models).
    """
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Check if models already exist
    needed_files = ["baseline.json", "xgboost.pkl", "transformer.pt"]
    all_present = all((model_dir / f).exists() for f in needed_files)

    if all_present:
        print("Models already present, skipping download.")
        return

    print("Downloading models from DTanzillo/panacea-models ...")
    try:
        from huggingface_hub import snapshot_download

        token = os.environ.get("HF_TOKEN")
        local = Path(snapshot_download(
            "DTanzillo/panacea-models",
            token=token,
            allow_patterns=["models/*", "results/*"],
        ))

        # Copy model and result files (only if missing locally).
        _copy_missing(local / "models", model_dir)
        _copy_missing(local / "results", results_dir, prefix="result: ")

        print("Model download complete.")
    except Exception as e:
        print(f"WARNING: Model download failed: {e}")
        print("The API will start but models may not be available.")
+
66
+
67
+ if __name__ == "__main__":
68
+ # Step 1: Download models
69
+ download_models()
70
+
71
+ # Step 2: Start uvicorn
72
+ import uvicorn
73
+ port = int(os.environ.get("PORT", 7860))
74
+ print(f"Starting Panacea API on port {port} ...")
75
+ uvicorn.run(
76
+ "app.main:app",
77
+ host="0.0.0.0",
78
+ port=port,
79
+ log_level="info",
80
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.6
2
+ uvicorn[standard]==0.34.0
3
+ xgboost==2.1.4
4
+ scikit-learn==1.6.1
5
+ pandas==2.2.3
6
+ numpy==2.2.2
7
+ huggingface-hub>=0.27.0
results/deep_model_results.json ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "PI-TFT (Physics-Informed Temporal Fusion Transformer)",
3
+ "best_epoch": 36,
4
+ "training_time_minutes": 13.175001474221547,
5
+ "optimal_threshold": 0.2639383375644684,
6
+ "temperature": 0.6179193258285522,
7
+ "use_density": true,
8
+ "test": {
9
+ "loss": 0.021245601093944383,
10
+ "auc_pr": 0.5076785607710974,
11
+ "auc_roc": 0.946749355627952,
12
+ "f1_at_50": 0.0,
13
+ "n_positive": 73,
14
+ "n_total": 2167,
15
+ "pos_rate": 0.03368712589144707,
16
+ "f1": 0.5185185137773299,
17
+ "optimal_threshold": 0.2639383375644684,
18
+ "threshold": 0.2639383375644684,
19
+ "recall_at_prec_30": 0.7808219178082192,
20
+ "recall_at_prec_50": 0.4931506849315068,
21
+ "recall_at_prec_70": 0.2876712328767123,
22
+ "mae_log": 0.10174570232629776,
23
+ "rmse_log": 0.15394317551905587,
24
+ "mae_km": 1533.616943359375,
25
+ "median_abs_error_km": 926.875
26
+ },
27
+ "test_calibrated": {
28
+ "auc_pr": 0.5076785607710974,
29
+ "auc_roc": 0.946749355627952,
30
+ "f1_at_50": 0.0,
31
+ "n_positive": 73,
32
+ "n_total": 2167,
33
+ "pos_rate": 0.03368712589144707,
34
+ "f1": 0.5185185137773299,
35
+ "optimal_threshold": 0.15979407727718353,
36
+ "threshold": 0.15979407727718353,
37
+ "recall_at_prec_30": 0.7808219178082192,
38
+ "recall_at_prec_50": 0.4931506849315068,
39
+ "recall_at_prec_70": 0.2876712328767123
40
+ },
41
+ "history": [
42
+ {
43
+ "epoch": 1,
44
+ "train_loss": 6.801268232190931,
45
+ "val_loss": 5.25680010659354,
46
+ "val_auc_pr": 0.007896529454622946,
47
+ "val_f1": 0.019323671026161646,
48
+ "val_mae_log": 7.11151123046875
49
+ },
50
+ {
51
+ "epoch": 2,
52
+ "train_loss": 3.834329532932591,
53
+ "val_loss": 2.7224787643977573,
54
+ "val_auc_pr": 0.010594418921027337,
55
+ "val_f1": 0.023529411193771638,
56
+ "val_mae_log": 5.041237831115723
57
+ },
58
+ {
59
+ "epoch": 3,
60
+ "train_loss": 1.955074778118649,
61
+ "val_loss": 1.1283516032355172,
62
+ "val_auc_pr": 0.008480584727743306,
63
+ "val_f1": 0.021505376131344667,
64
+ "val_mae_log": 3.112034797668457
65
+ },
66
+ {
67
+ "epoch": 4,
68
+ "train_loss": 0.6309667991625296,
69
+ "val_loss": 0.2000983421291624,
70
+ "val_auc_pr": 0.047413803659580166,
71
+ "val_f1": 0.11764705467128042,
72
+ "val_mae_log": 1.1961653232574463
73
+ },
74
+ {
75
+ "epoch": 5,
76
+ "train_loss": 0.13499785540877163,
77
+ "val_loss": 0.02656353052173342,
78
+ "val_auc_pr": 0.05766442486817594,
79
+ "val_f1": 0.15999999680000007,
80
+ "val_mae_log": 0.29869771003723145
81
+ },
82
+ {
83
+ "epoch": 6,
84
+ "train_loss": 0.07689017317182309,
85
+ "val_loss": 0.02750414184161595,
86
+ "val_auc_pr": 0.134885373440643,
87
+ "val_f1": 0.27272726921487606,
88
+ "val_mae_log": 0.3075650930404663
89
+ },
90
+ {
91
+ "epoch": 7,
92
+ "train_loss": 0.08175783813805193,
93
+ "val_loss": 0.07211375555821828,
94
+ "val_auc_pr": 0.18529914529914526,
95
+ "val_f1": 0.4285714239795918,
96
+ "val_mae_log": 0.6812126040458679
97
+ },
98
+ {
99
+ "epoch": 8,
100
+ "train_loss": 0.07750273872468923,
101
+ "val_loss": 0.027415024914911816,
102
+ "val_auc_pr": 0.13237697916045849,
103
+ "val_f1": 0.3157894698060942,
104
+ "val_mae_log": 0.35104697942733765
105
+ },
106
+ {
107
+ "epoch": 9,
108
+ "train_loss": 0.06653158048520218,
109
+ "val_loss": 0.01911477212394987,
110
+ "val_auc_pr": 0.20693184703085693,
111
+ "val_f1": 0.374999995703125,
112
+ "val_mae_log": 0.2960411608219147
113
+ },
114
+ {
115
+ "epoch": 10,
116
+ "train_loss": 0.0626621154917253,
117
+ "val_loss": 0.020604882389307022,
118
+ "val_auc_pr": 0.3348872180451128,
119
+ "val_f1": 0.5454545404958678,
120
+ "val_mae_log": 0.23688556253910065
121
+ },
122
+ {
123
+ "epoch": 11,
124
+ "train_loss": 0.0617836594581604,
125
+ "val_loss": 0.012763384197439467,
126
+ "val_auc_pr": 0.1294155844155844,
127
+ "val_f1": 0.22222221920438956,
128
+ "val_mae_log": 0.1817978173494339
129
+ },
130
+ {
131
+ "epoch": 12,
132
+ "train_loss": 0.05554375463240856,
133
+ "val_loss": 0.01185049262962171,
134
+ "val_auc_pr": 0.24263038548752833,
135
+ "val_f1": 0.36363635867768596,
136
+ "val_mae_log": 0.15147316455841064
137
+ },
138
+ {
139
+ "epoch": 13,
140
+ "train_loss": 0.05319682077781574,
141
+ "val_loss": 0.017937806567975452,
142
+ "val_auc_pr": 0.2786109128966272,
143
+ "val_f1": 0.33333333055555564,
144
+ "val_mae_log": 0.21772687137126923
145
+ },
146
+ {
147
+ "epoch": 14,
148
+ "train_loss": 0.05603743799634882,
149
+ "val_loss": 0.012255215285612004,
150
+ "val_auc_pr": 0.1654839208410637,
151
+ "val_f1": 0.3076923029585799,
152
+ "val_mae_log": 0.12889182567596436
153
+ },
154
+ {
155
+ "epoch": 15,
156
+ "train_loss": 0.052231158416818926,
157
+ "val_loss": 0.008827194571495056,
158
+ "val_auc_pr": 0.30569487983281085,
159
+ "val_f1": 0.4705882311418686,
160
+ "val_mae_log": 0.11871597170829773
161
+ },
162
+ {
163
+ "epoch": 16,
164
+ "train_loss": 0.050459702796227225,
165
+ "val_loss": 0.006688231070126806,
166
+ "val_auc_pr": 0.3174495864073329,
167
+ "val_f1": 0.33333333055555564,
168
+ "val_mae_log": 0.11670727282762527
169
+ },
170
+ {
171
+ "epoch": 17,
172
+ "train_loss": 0.05048987201943591,
173
+ "val_loss": 0.012136828287371568,
174
+ "val_auc_pr": 0.209023569023569,
175
+ "val_f1": 0.3529411723183391,
176
+ "val_mae_log": 0.15395033359527588
177
+ },
178
+ {
179
+ "epoch": 18,
180
+ "train_loss": 0.05087649694367035,
181
+ "val_loss": 0.007568871269800833,
182
+ "val_auc_pr": 0.2673856209150327,
183
+ "val_f1": 0.3999999962500001,
184
+ "val_mae_log": 0.1411171853542328
185
+ },
186
+ {
187
+ "epoch": 19,
188
+ "train_loss": 0.050642090935159374,
189
+ "val_loss": 0.0066412134495164666,
190
+ "val_auc_pr": 0.27475908192734455,
191
+ "val_f1": 0.3999999955555556,
192
+ "val_mae_log": 0.0915408581495285
193
+ },
194
+ {
195
+ "epoch": 20,
196
+ "train_loss": 0.04991532632628003,
197
+ "val_loss": 0.0055730888686542,
198
+ "val_auc_pr": 0.24940384615384617,
199
+ "val_f1": 0.33333332932098775,
200
+ "val_mae_log": 0.10347151011228561
201
+ },
202
+ {
203
+ "epoch": 21,
204
+ "train_loss": 0.049406778288854133,
205
+ "val_loss": 0.008397463309977735,
206
+ "val_auc_pr": 0.22877207681961503,
207
+ "val_f1": 0.2857142816326531,
208
+ "val_mae_log": 0.15620921552181244
209
+ },
210
+ {
211
+ "epoch": 22,
212
+ "train_loss": 0.04929839575008766,
213
+ "val_loss": 0.0075396452365177015,
214
+ "val_auc_pr": 0.3359158185268243,
215
+ "val_f1": 0.33333333055555564,
216
+ "val_mae_log": 0.11639901250600815
217
+ },
218
+ {
219
+ "epoch": 23,
220
+ "train_loss": 0.04896112705606061,
221
+ "val_loss": 0.007832049591732877,
222
+ "val_auc_pr": 0.3431446821152704,
223
+ "val_f1": 0.36363636012396694,
224
+ "val_mae_log": 0.10894307494163513
225
+ },
226
+ {
227
+ "epoch": 24,
228
+ "train_loss": 0.048813931744646384,
229
+ "val_loss": 0.0061542981836412635,
230
+ "val_auc_pr": 0.3559577677224736,
231
+ "val_f1": 0.36363636012396694,
232
+ "val_mae_log": 0.07847719639539719
233
+ },
234
+ {
235
+ "epoch": 25,
236
+ "train_loss": 0.04768835706888019,
237
+ "val_loss": 0.006223144009709358,
238
+ "val_auc_pr": 0.3659761291340239,
239
+ "val_f1": 0.421052627700831,
240
+ "val_mae_log": 0.14390207827091217
241
+ },
242
+ {
243
+ "epoch": 26,
244
+ "train_loss": 0.04840076712740434,
245
+ "val_loss": 0.0067752449374113765,
246
+ "val_auc_pr": 0.2586657651566374,
247
+ "val_f1": 0.34782608355387534,
248
+ "val_mae_log": 0.1449323147535324
249
+ },
250
+ {
251
+ "epoch": 27,
252
+ "train_loss": 0.047609428044509246,
253
+ "val_loss": 0.0065139371103474075,
254
+ "val_auc_pr": 0.34384112619406737,
255
+ "val_f1": 0.34782608355387534,
256
+ "val_mae_log": 0.09073375165462494
257
+ },
258
+ {
259
+ "epoch": 28,
260
+ "train_loss": 0.04662630880201185,
261
+ "val_loss": 0.006256445976240295,
262
+ "val_auc_pr": 0.33832141293241863,
263
+ "val_f1": 0.33333333055555564,
264
+ "val_mae_log": 0.07596895098686218
265
+ },
266
+ {
267
+ "epoch": 29,
268
+ "train_loss": 0.04634691820152708,
269
+ "val_loss": 0.005017333896830678,
270
+ "val_auc_pr": 0.336514012303486,
271
+ "val_f1": 0.33333333055555564,
272
+ "val_mae_log": 0.07677556574344635
273
+ },
274
+ {
275
+ "epoch": 30,
276
+ "train_loss": 0.04663669626052315,
277
+ "val_loss": 0.004762223763723991,
278
+ "val_auc_pr": 0.24682988580047405,
279
+ "val_f1": 0.36363636012396694,
280
+ "val_mae_log": 0.08992886543273926
281
+ },
282
+ {
283
+ "epoch": 31,
284
+ "train_loss": 0.046282403110652355,
285
+ "val_loss": 0.003826435888186097,
286
+ "val_auc_pr": 0.2284485407066052,
287
+ "val_f1": 0.3999999962500001,
288
+ "val_mae_log": 0.06141701713204384
289
+ },
290
+ {
291
+ "epoch": 32,
292
+ "train_loss": 0.04575154318197353,
293
+ "val_loss": 0.005115043604746461,
294
+ "val_auc_pr": 0.3611255411255411,
295
+ "val_f1": 0.3999999962500001,
296
+ "val_mae_log": 0.09008380770683289
297
+ },
298
+ {
299
+ "epoch": 33,
300
+ "train_loss": 0.046043931763317135,
301
+ "val_loss": 0.004483342935730304,
302
+ "val_auc_pr": 0.36333333333333334,
303
+ "val_f1": 0.3809523773242631,
304
+ "val_mae_log": 0.10232321172952652
305
+ },
306
+ {
307
+ "epoch": 34,
308
+ "train_loss": 0.04492839058307377,
309
+ "val_loss": 0.007276699944798436,
310
+ "val_auc_pr": 0.3461904761904762,
311
+ "val_f1": 0.3809523773242631,
312
+ "val_mae_log": 0.10686437785625458
313
+ },
314
+ {
315
+ "epoch": 35,
316
+ "train_loss": 0.04576677558188503,
317
+ "val_loss": 0.004259714224774923,
318
+ "val_auc_pr": 0.37718954248366016,
319
+ "val_f1": 0.3999999962500001,
320
+ "val_mae_log": 0.0769796371459961
321
+ },
322
+ {
323
+ "epoch": 36,
324
+ "train_loss": 0.044130372638637956,
325
+ "val_loss": 0.004274079659288483,
326
+ "val_auc_pr": 0.4215151515151515,
327
+ "val_f1": 0.4444444395061729,
328
+ "val_mae_log": 0.09318451583385468
329
+ },
330
+ {
331
+ "epoch": 37,
332
+ "train_loss": 0.04556343443691731,
333
+ "val_loss": 0.0053521015548280305,
334
+ "val_auc_pr": 0.3828373015873016,
335
+ "val_f1": 0.421052627700831,
336
+ "val_mae_log": 0.11446798592805862
337
+ },
338
+ {
339
+ "epoch": 38,
340
+ "train_loss": 0.04497031863476779,
341
+ "val_loss": 0.005016647595246988,
342
+ "val_auc_pr": 0.38186813186813184,
343
+ "val_f1": 0.3809523773242631,
344
+ "val_mae_log": 0.11497646570205688
345
+ },
346
+ {
347
+ "epoch": 39,
348
+ "train_loss": 0.04312905277553442,
349
+ "val_loss": 0.003749881671475513,
350
+ "val_auc_pr": 0.3595238095238095,
351
+ "val_f1": 0.3809523773242631,
352
+ "val_mae_log": 0.05548140034079552
353
+ },
354
+ {
355
+ "epoch": 40,
356
+ "train_loss": 0.04352163130769859,
357
+ "val_loss": 0.005372332009885993,
358
+ "val_auc_pr": 0.3503288825869471,
359
+ "val_f1": 0.34782608355387534,
360
+ "val_mae_log": 0.08230870962142944
361
+ },
362
+ {
363
+ "epoch": 41,
364
+ "train_loss": 0.043740846146200156,
365
+ "val_loss": 0.0039979582319834405,
366
+ "val_auc_pr": 0.41458333333333336,
367
+ "val_f1": 0.3999999962500001,
368
+ "val_mae_log": 0.08734633028507233
369
+ },
370
+ {
371
+ "epoch": 42,
372
+ "train_loss": 0.04409235781310378,
373
+ "val_loss": 0.005109895303446267,
374
+ "val_auc_pr": 0.2524756335282651,
375
+ "val_f1": 0.33333333003472226,
376
+ "val_mae_log": 0.07870446890592575
377
+ },
378
+ {
379
+ "epoch": 43,
380
+ "train_loss": 0.043179894389735685,
381
+ "val_loss": 0.005041864268215639,
382
+ "val_auc_pr": 0.26508912655971484,
383
+ "val_f1": 0.36363636012396694,
384
+ "val_mae_log": 0.07578516006469727
385
+ },
386
+ {
387
+ "epoch": 44,
388
+ "train_loss": 0.04234155755792115,
389
+ "val_loss": 0.0038543779269925187,
390
+ "val_auc_pr": 0.3427519893899204,
391
+ "val_f1": 0.33333333055555564,
392
+ "val_mae_log": 0.06378159672021866
393
+ },
394
+ {
395
+ "epoch": 45,
396
+ "train_loss": 0.043199574021068776,
397
+ "val_loss": 0.00448337330349854,
398
+ "val_auc_pr": 0.38693977591036416,
399
+ "val_f1": 0.36363636012396694,
400
+ "val_mae_log": 0.08112290501594543
401
+ },
402
+ {
403
+ "epoch": 46,
404
+ "train_loss": 0.04324697579282361,
405
+ "val_loss": 0.004593804511906845,
406
+ "val_auc_pr": 0.3657142857142857,
407
+ "val_f1": 0.3809523773242631,
408
+ "val_mae_log": 0.12126877903938293
409
+ },
410
+ {
411
+ "epoch": 47,
412
+ "train_loss": 0.042983541144309814,
413
+ "val_loss": 0.0034202520120223717,
414
+ "val_auc_pr": 0.36703703703703705,
415
+ "val_f1": 0.3809523773242631,
416
+ "val_mae_log": 0.05318637564778328
417
+ },
418
+ {
419
+ "epoch": 48,
420
+ "train_loss": 0.04088504479543583,
421
+ "val_loss": 0.0037384599480511887,
422
+ "val_auc_pr": 0.35812684047978166,
423
+ "val_f1": 0.38461538150887575,
424
+ "val_mae_log": 0.0607416033744812
425
+ },
426
+ {
427
+ "epoch": 49,
428
+ "train_loss": 0.0411647165143812,
429
+ "val_loss": 0.0038923417118244936,
430
+ "val_auc_pr": 0.37444444444444447,
431
+ "val_f1": 0.3809523773242631,
432
+ "val_mae_log": 0.07454186677932739
433
+ },
434
+ {
435
+ "epoch": 50,
436
+ "train_loss": 0.04235347539589212,
437
+ "val_loss": 0.0035431724141484927,
438
+ "val_auc_pr": 0.3718181818181818,
439
+ "val_f1": 0.3809523773242631,
440
+ "val_mae_log": 0.05186235159635544
441
+ },
442
+ {
443
+ "epoch": 51,
444
+ "train_loss": 0.03975096909782371,
445
+ "val_loss": 0.003855357279202768,
446
+ "val_auc_pr": 0.37,
447
+ "val_f1": 0.3809523773242631,
448
+ "val_mae_log": 0.08433445543050766
449
+ },
450
+ {
451
+ "epoch": 52,
452
+ "train_loss": 0.040304526777283564,
453
+ "val_loss": 0.003954493274380054,
454
+ "val_auc_pr": 0.36705882352941177,
455
+ "val_f1": 0.36363636012396694,
456
+ "val_mae_log": 0.0650041252374649
457
+ },
458
+ {
459
+ "epoch": 53,
460
+ "train_loss": 0.041316902365636184,
461
+ "val_loss": 0.0044658422370308214,
462
+ "val_auc_pr": 0.37444444444444447,
463
+ "val_f1": 0.39999999680000003,
464
+ "val_mae_log": 0.08514165133237839
465
+ },
466
+ {
467
+ "epoch": 54,
468
+ "train_loss": 0.041085500773545856,
469
+ "val_loss": 0.003584100299381784,
470
+ "val_auc_pr": 0.36991596638655466,
471
+ "val_f1": 0.36363636012396694,
472
+ "val_mae_log": 0.04943912476301193
473
+ },
474
+ {
475
+ "epoch": 55,
476
+ "train_loss": 0.04048956327543066,
477
+ "val_loss": 0.003669723236401166,
478
+ "val_auc_pr": 0.366961926961927,
479
+ "val_f1": 0.34782608355387534,
480
+ "val_mae_log": 0.0743192732334137
481
+ },
482
+ {
483
+ "epoch": 56,
484
+ "train_loss": 0.04016674624101536,
485
+ "val_loss": 0.004304527521266469,
486
+ "val_auc_pr": 0.3745588235294118,
487
+ "val_f1": 0.39999999680000003,
488
+ "val_mae_log": 0.08440288156270981
489
+ }
490
+ ],
491
+ "conformal": {
492
+ "alpha_0.01": {
493
+ "conformal_metrics": {
494
+ "alpha": 0.01,
495
+ "target_coverage": 0.99,
496
+ "marginal_coverage": 0.9700046146746655,
497
+ "coverage_guarantee_met": false,
498
+ "avg_set_size": 2.1033687125057683,
499
+ "efficiency": 0.4741578218735579,
500
+ "positive_coverage": 0.136986301369863,
501
+ "negative_coverage": 0.9990448901623686,
502
+ "set_size_distribution": {
503
+ "2": 1948,
504
+ "3": 214,
505
+ "4": 5
506
+ },
507
+ "n_test": 2167,
508
+ "mean_interval_width": 0.35249775648117065,
509
+ "median_interval_width": 0.3299492597579956
510
+ },
511
+ "conformal_state": {
512
+ "is_calibrated": true,
513
+ "alpha": 0.01,
514
+ "q_hat": 0.31530878875241947,
515
+ "q_residual": 0.31530878875241947,
516
+ "n_cal": 527,
517
+ "tiers": {
518
+ "LOW": [
519
+ 0.0,
520
+ 0.1
521
+ ],
522
+ "MODERATE": [
523
+ 0.1,
524
+ 0.4
525
+ ],
526
+ "HIGH": [
527
+ 0.4,
528
+ 0.7
529
+ ],
530
+ "CRITICAL": [
531
+ 0.7,
532
+ 1.0
533
+ ]
534
+ }
535
+ }
536
+ },
537
+ "alpha_0.05": {
538
+ "conformal_metrics": {
539
+ "alpha": 0.05,
540
+ "target_coverage": 0.95,
541
+ "marginal_coverage": 0.9487771112136595,
542
+ "coverage_guarantee_met": true,
543
+ "avg_set_size": 1.9856945085371482,
544
+ "efficiency": 0.503576372865713,
545
+ "positive_coverage": 0.0,
546
+ "negative_coverage": 0.9818529130850048,
547
+ "set_size_distribution": {
548
+ "1": 31,
549
+ "2": 2136
550
+ },
551
+ "n_test": 2167,
552
+ "mean_interval_width": 0.14139389991760254,
553
+ "median_interval_width": 0.1266784965991974
554
+ },
555
+ "conformal_state": {
556
+ "is_calibrated": true,
557
+ "alpha": 0.05,
558
+ "q_hat": 0.1120380280677236,
559
+ "q_residual": 0.1120380280677236,
560
+ "n_cal": 527,
561
+ "tiers": {
562
+ "LOW": [
563
+ 0.0,
564
+ 0.1
565
+ ],
566
+ "MODERATE": [
567
+ 0.1,
568
+ 0.4
569
+ ],
570
+ "HIGH": [
571
+ 0.4,
572
+ 0.7
573
+ ],
574
+ "CRITICAL": [
575
+ 0.7,
576
+ 1.0
577
+ ]
578
+ }
579
+ }
580
+ },
581
+ "alpha_0.1": {
582
+ "conformal_metrics": {
583
+ "alpha": 0.1,
584
+ "target_coverage": 0.9,
585
+ "marginal_coverage": 0.9284725426857406,
586
+ "coverage_guarantee_met": true,
587
+ "avg_set_size": 1.103830179972312,
588
+ "efficiency": 0.724042455006922,
589
+ "positive_coverage": 0.0,
590
+ "negative_coverage": 0.9608404966571156,
591
+ "set_size_distribution": {
592
+ "1": 1942,
593
+ "2": 225
594
+ },
595
+ "n_test": 2167,
596
+ "mean_interval_width": 0.060726769268512726,
597
+ "median_interval_width": 0.05510023236274719
598
+ },
599
+ "conformal_state": {
600
+ "is_calibrated": true,
601
+ "alpha": 0.1,
602
+ "q_hat": 0.04045976169647709,
603
+ "q_residual": 0.04045976169647709,
604
+ "n_cal": 527,
605
+ "tiers": {
606
+ "LOW": [
607
+ 0.0,
608
+ 0.1
609
+ ],
610
+ "MODERATE": [
611
+ 0.1,
612
+ 0.4
613
+ ],
614
+ "HIGH": [
615
+ 0.4,
616
+ 0.7
617
+ ],
618
+ "CRITICAL": [
619
+ 0.7,
620
+ 1.0
621
+ ]
622
+ }
623
+ }
624
+ },
625
+ "alpha_0.2": {
626
+ "conformal_metrics": {
627
+ "alpha": 0.2,
628
+ "target_coverage": 0.8,
629
+ "marginal_coverage": 0.9220119981541302,
630
+ "coverage_guarantee_met": true,
631
+ "avg_set_size": 1.054453161052146,
632
+ "efficiency": 0.7363867097369635,
633
+ "positive_coverage": 0.0,
634
+ "negative_coverage": 0.9541547277936963,
635
+ "set_size_distribution": {
636
+ "1": 2049,
637
+ "2": 118
638
+ },
639
+ "n_test": 2167,
640
+ "mean_interval_width": 0.04071307182312012,
641
+ "median_interval_width": 0.039181869477033615
642
+ },
643
+ "conformal_state": {
644
+ "is_calibrated": true,
645
+ "alpha": 0.2,
646
+ "q_hat": 0.024541400479014954,
647
+ "q_residual": 0.024541400479014954,
648
+ "n_cal": 527,
649
+ "tiers": {
650
+ "LOW": [
651
+ 0.0,
652
+ 0.1
653
+ ],
654
+ "MODERATE": [
655
+ 0.1,
656
+ 0.4
657
+ ],
658
+ "HIGH": [
659
+ 0.4,
660
+ 0.7
661
+ ],
662
+ "CRITICAL": [
663
+ 0.7,
664
+ 1.0
665
+ ]
666
+ }
667
+ }
668
+ }
669
+ }
670
+ }
results/model_comparison.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "model": "Orbital Shell Baseline",
4
+ "auc_pr": 0.061184346220415166,
5
+ "auc_roc": 0.6374507725922728,
6
+ "f1_at_50": 0.0,
7
+ "n_positive": 73,
8
+ "n_total": 2167,
9
+ "pos_rate": 0.03368712505768343,
10
+ "f1": 0.13223140017211957,
11
+ "optimal_threshold": 0.03237410071942446,
12
+ "threshold": 0.03237410071942446,
13
+ "recall_at_prec_30": 0.0,
14
+ "recall_at_prec_50": 0.0,
15
+ "recall_at_prec_70": 0.0,
16
+ "mae_log": 0.9927019602313063,
17
+ "rmse_log": 1.2867684153860748,
18
+ "mae_km": 10600.126897201788,
19
+ "median_abs_error_km": 7222.8428976622645
20
+ },
21
+ {
22
+ "model": "XGBoost (Engineered Features)",
23
+ "auc_pr": 0.9884220304219559,
24
+ "auc_roc": 0.9995944054114168,
25
+ "f1_at_50": 0.9411764705882353,
26
+ "n_positive": 73,
27
+ "n_total": 2167,
28
+ "pos_rate": 0.03368712505768343,
29
+ "f1": 0.9473684160604224,
30
+ "optimal_threshold": 0.5539590716362,
31
+ "threshold": 0.5539590716362,
32
+ "recall_at_prec_30": 1.0,
33
+ "recall_at_prec_50": 1.0,
34
+ "recall_at_prec_70": 1.0,
35
+ "mae_log": 0.011742588180292227,
36
+ "rmse_log": 0.03972278871639667,
37
+ "mae_km": 80.85688587394668,
38
+ "median_abs_error_km": 42.99218749998545
39
+ }
40
+ ]
results/staleness_experiment.json ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cutoffs": [
3
+ 2.0,
4
+ 2.5,
5
+ 3.0,
6
+ 3.5,
7
+ 4.0,
8
+ 5.0,
9
+ 6.0
10
+ ],
11
+ "n_test_events": 2167,
12
+ "n_positive": 73,
13
+ "baseline": [
14
+ {
15
+ "auc_pr": 0.061184346220415166,
16
+ "auc_roc": 0.6374507725922728,
17
+ "f1_at_50": 0.0,
18
+ "n_positive": 73,
19
+ "n_total": 2167,
20
+ "pos_rate": 0.03368712505768343,
21
+ "f1": 0.13223140017211957,
22
+ "optimal_threshold": 0.03237410071942446,
23
+ "threshold": 0.03237410071942446,
24
+ "recall_at_prec_30": 0.0,
25
+ "recall_at_prec_50": 0.0,
26
+ "recall_at_prec_70": 0.0,
27
+ "cutoff": 2.0,
28
+ "n_events": 2167
29
+ },
30
+ {
31
+ "auc_pr": 0.061184346220415166,
32
+ "auc_roc": 0.6374507725922728,
33
+ "f1_at_50": 0.0,
34
+ "n_positive": 73,
35
+ "n_total": 2167,
36
+ "pos_rate": 0.03368712505768343,
37
+ "f1": 0.13223140017211957,
38
+ "optimal_threshold": 0.03237410071942446,
39
+ "threshold": 0.03237410071942446,
40
+ "recall_at_prec_30": 0.0,
41
+ "recall_at_prec_50": 0.0,
42
+ "recall_at_prec_70": 0.0,
43
+ "cutoff": 2.5,
44
+ "n_events": 2167
45
+ },
46
+ {
47
+ "auc_pr": 0.061184346220415166,
48
+ "auc_roc": 0.6374507725922728,
49
+ "f1_at_50": 0.0,
50
+ "n_positive": 73,
51
+ "n_total": 2167,
52
+ "pos_rate": 0.03368712505768343,
53
+ "f1": 0.13223140017211957,
54
+ "optimal_threshold": 0.03237410071942446,
55
+ "threshold": 0.03237410071942446,
56
+ "recall_at_prec_30": 0.0,
57
+ "recall_at_prec_50": 0.0,
58
+ "recall_at_prec_70": 0.0,
59
+ "cutoff": 3.0,
60
+ "n_events": 2167
61
+ },
62
+ {
63
+ "auc_pr": 0.061184346220415166,
64
+ "auc_roc": 0.6374507725922728,
65
+ "f1_at_50": 0.0,
66
+ "n_positive": 73,
67
+ "n_total": 2167,
68
+ "pos_rate": 0.03368712505768343,
69
+ "f1": 0.13223140017211957,
70
+ "optimal_threshold": 0.03237410071942446,
71
+ "threshold": 0.03237410071942446,
72
+ "recall_at_prec_30": 0.0,
73
+ "recall_at_prec_50": 0.0,
74
+ "recall_at_prec_70": 0.0,
75
+ "cutoff": 3.5,
76
+ "n_events": 2167
77
+ },
78
+ {
79
+ "auc_pr": 0.061184346220415166,
80
+ "auc_roc": 0.6374507725922728,
81
+ "f1_at_50": 0.0,
82
+ "n_positive": 73,
83
+ "n_total": 2167,
84
+ "pos_rate": 0.03368712505768343,
85
+ "f1": 0.13223140017211957,
86
+ "optimal_threshold": 0.03237410071942446,
87
+ "threshold": 0.03237410071942446,
88
+ "recall_at_prec_30": 0.0,
89
+ "recall_at_prec_50": 0.0,
90
+ "recall_at_prec_70": 0.0,
91
+ "cutoff": 4.0,
92
+ "n_events": 2167
93
+ },
94
+ {
95
+ "auc_pr": 0.061184346220415166,
96
+ "auc_roc": 0.6374507725922728,
97
+ "f1_at_50": 0.0,
98
+ "n_positive": 73,
99
+ "n_total": 2167,
100
+ "pos_rate": 0.03368712505768343,
101
+ "f1": 0.13223140017211957,
102
+ "optimal_threshold": 0.03237410071942446,
103
+ "threshold": 0.03237410071942446,
104
+ "recall_at_prec_30": 0.0,
105
+ "recall_at_prec_50": 0.0,
106
+ "recall_at_prec_70": 0.0,
107
+ "cutoff": 5.0,
108
+ "n_events": 2167
109
+ },
110
+ {
111
+ "auc_pr": 0.061184346220415166,
112
+ "auc_roc": 0.6374507725922728,
113
+ "f1_at_50": 0.0,
114
+ "n_positive": 73,
115
+ "n_total": 2167,
116
+ "pos_rate": 0.03368712505768343,
117
+ "f1": 0.13223140017211957,
118
+ "optimal_threshold": 0.03237410071942446,
119
+ "threshold": 0.03237410071942446,
120
+ "recall_at_prec_30": 0.0,
121
+ "recall_at_prec_50": 0.0,
122
+ "recall_at_prec_70": 0.0,
123
+ "cutoff": 6.0,
124
+ "n_events": 2167
125
+ }
126
+ ],
127
+ "xgboost": [
128
+ {
129
+ "auc_pr": 0.9883137899600032,
130
+ "auc_roc": 0.9995878635632139,
131
+ "f1_at_50": 0.935064935064935,
132
+ "n_positive": 73,
133
+ "n_total": 2167,
134
+ "pos_rate": 0.03368712505768343,
135
+ "f1": 0.9411764655987015,
136
+ "optimal_threshold": 0.5284891724586487,
137
+ "threshold": 0.5284891724586487,
138
+ "recall_at_prec_30": 1.0,
139
+ "recall_at_prec_50": 1.0,
140
+ "recall_at_prec_70": 1.0,
141
+ "cutoff": 2.0,
142
+ "n_events": 2167
143
+ },
144
+ {
145
+ "auc_pr": 0.9123203140942627,
146
+ "auc_roc": 0.9903418565869928,
147
+ "f1_at_50": 0.8421052631578947,
148
+ "n_positive": 70,
149
+ "n_total": 2126,
150
+ "pos_rate": 0.03292568203198495,
151
+ "f1": 0.8467153234695509,
152
+ "optimal_threshold": 0.9780168533325195,
153
+ "threshold": 0.9780168533325195,
154
+ "recall_at_prec_30": 0.9857142857142858,
155
+ "recall_at_prec_50": 0.9714285714285714,
156
+ "recall_at_prec_70": 0.9285714285714286,
157
+ "cutoff": 2.5,
158
+ "n_events": 2126
159
+ },
160
+ {
161
+ "auc_pr": 0.7112636105696798,
162
+ "auc_roc": 0.9702624390685601,
163
+ "f1_at_50": 0.7012987012987013,
164
+ "n_positive": 67,
165
+ "n_total": 2045,
166
+ "pos_rate": 0.03276283618581907,
167
+ "f1": 0.722222217246335,
168
+ "optimal_threshold": 0.9061354398727417,
169
+ "threshold": 0.9061354398727417,
170
+ "recall_at_prec_30": 0.9104477611940298,
171
+ "recall_at_prec_50": 0.8507462686567164,
172
+ "recall_at_prec_70": 0.7164179104477612,
173
+ "cutoff": 3.0,
174
+ "n_events": 2045
175
+ },
176
+ {
177
+ "auc_pr": 0.7224173760553306,
178
+ "auc_roc": 0.9779084384250436,
179
+ "f1_at_50": 0.6666666666666666,
180
+ "n_positive": 65,
181
+ "n_total": 1962,
182
+ "pos_rate": 0.033129459734964326,
183
+ "f1": 0.6802721039104078,
184
+ "optimal_threshold": 0.8590014576911926,
185
+ "threshold": 0.8590014576911926,
186
+ "recall_at_prec_30": 0.9384615384615385,
187
+ "recall_at_prec_50": 0.8615384615384616,
188
+ "recall_at_prec_70": 0.6153846153846154,
189
+ "cutoff": 3.5,
190
+ "n_events": 1962
191
+ },
192
+ {
193
+ "auc_pr": 0.6392429519999454,
194
+ "auc_roc": 0.9669743064869061,
195
+ "f1_at_50": 0.5921052631578947,
196
+ "n_positive": 62,
197
+ "n_total": 1890,
198
+ "pos_rate": 0.0328042328042328,
199
+ "f1": 0.6370370320702333,
200
+ "optimal_threshold": 0.8714247941970825,
201
+ "threshold": 0.8714247941970825,
202
+ "recall_at_prec_30": 0.8870967741935484,
203
+ "recall_at_prec_50": 0.8064516129032258,
204
+ "recall_at_prec_70": 0.41935483870967744,
205
+ "cutoff": 4.0,
206
+ "n_events": 1890
207
+ },
208
+ {
209
+ "auc_pr": 0.42295193898950256,
210
+ "auc_roc": 0.9482351744481741,
211
+ "f1_at_50": 0.5419354838709678,
212
+ "n_positive": 58,
213
+ "n_total": 1753,
214
+ "pos_rate": 0.03308613804905876,
215
+ "f1": 0.5454545404630832,
216
+ "optimal_threshold": 0.9965507984161377,
217
+ "threshold": 0.9965507984161377,
218
+ "recall_at_prec_30": 0.7931034482758621,
219
+ "recall_at_prec_50": 0.5689655172413793,
220
+ "recall_at_prec_70": 0.0,
221
+ "cutoff": 5.0,
222
+ "n_events": 1753
223
+ },
224
+ {
225
+ "auc_pr": 0.3219032626600778,
226
+ "auc_roc": 0.9162752848174842,
227
+ "f1_at_50": 0.4027777777777778,
228
+ "n_positive": 55,
229
+ "n_total": 1619,
230
+ "pos_rate": 0.033971587399629403,
231
+ "f1": 0.42592592092764064,
232
+ "optimal_threshold": 0.9984425902366638,
233
+ "threshold": 0.9984425902366638,
234
+ "recall_at_prec_30": 0.5818181818181818,
235
+ "recall_at_prec_50": 0.12727272727272726,
236
+ "recall_at_prec_70": 0.01818181818181818,
237
+ "cutoff": 6.0,
238
+ "n_events": 1619
239
+ }
240
+ ],
241
+ "pitft": [
242
+ {
243
+ "auc_pr": 0.5108315323239697,
244
+ "auc_roc": 0.9467689811725608,
245
+ "f1_at_50": 0.0,
246
+ "n_positive": 73,
247
+ "n_total": 2167,
248
+ "pos_rate": 0.03368712505768343,
249
+ "f1": 0.5325443737908337,
250
+ "optimal_threshold": 0.18103967607021332,
251
+ "threshold": 0.18103967607021332,
252
+ "recall_at_prec_30": 0.7808219178082192,
253
+ "recall_at_prec_50": 0.5068493150684932,
254
+ "recall_at_prec_70": 0.2876712328767123,
255
+ "cutoff": 2.0,
256
+ "n_events": 2167
257
+ },
258
+ {
259
+ "auc_pr": 0.40929547300496166,
260
+ "auc_roc": 0.9342620900500278,
261
+ "f1_at_50": 0.028169014084507043,
262
+ "n_positive": 70,
263
+ "n_total": 2126,
264
+ "pos_rate": 0.03292568203198495,
265
+ "f1": 0.45121950730220106,
266
+ "optimal_threshold": 0.18565748631954193,
267
+ "threshold": 0.18565748631954193,
268
+ "recall_at_prec_30": 0.6571428571428571,
269
+ "recall_at_prec_50": 0.35714285714285715,
270
+ "recall_at_prec_70": 0.2,
271
+ "cutoff": 2.5,
272
+ "n_events": 2126
273
+ },
274
+ {
275
+ "auc_pr": 0.3126159912723518,
276
+ "auc_roc": 0.9086669785551514,
277
+ "f1_at_50": 0.056338028169014086,
278
+ "n_positive": 67,
279
+ "n_total": 2045,
280
+ "pos_rate": 0.03276283618581907,
281
+ "f1": 0.3968253918455531,
282
+ "optimal_threshold": 0.2572215497493744,
283
+ "threshold": 0.2572215497493744,
284
+ "recall_at_prec_30": 0.4626865671641791,
285
+ "recall_at_prec_50": 0.208955223880597,
286
+ "recall_at_prec_70": 0.0,
287
+ "cutoff": 3.0,
288
+ "n_events": 2045
289
+ },
290
+ {
291
+ "auc_pr": 0.32548992974654617,
292
+ "auc_roc": 0.9031263939013017,
293
+ "f1_at_50": 0.058823529411764705,
294
+ "n_positive": 65,
295
+ "n_total": 1962,
296
+ "pos_rate": 0.033129459734964326,
297
+ "f1": 0.3716814110423683,
298
+ "optimal_threshold": 0.28492599725723267,
299
+ "threshold": 0.28492599725723267,
300
+ "recall_at_prec_30": 0.46153846153846156,
301
+ "recall_at_prec_50": 0.26153846153846155,
302
+ "recall_at_prec_70": 0.015384615384615385,
303
+ "cutoff": 3.5,
304
+ "n_events": 1962
305
+ },
306
+ {
307
+ "auc_pr": 0.286925285041537,
308
+ "auc_roc": 0.892249594127197,
309
+ "f1_at_50": 0.0,
310
+ "n_positive": 62,
311
+ "n_total": 1890,
312
+ "pos_rate": 0.0328042328042328,
313
+ "f1": 0.3736263691341626,
314
+ "optimal_threshold": 0.16788320243358612,
315
+ "threshold": 0.16788320243358612,
316
+ "recall_at_prec_30": 0.45161290322580644,
317
+ "recall_at_prec_50": 0.22580645161290322,
318
+ "recall_at_prec_70": 0.0,
319
+ "cutoff": 4.0,
320
+ "n_events": 1890
321
+ },
322
+ {
323
+ "auc_pr": 0.23877494536053875,
324
+ "auc_roc": 0.867622825755264,
325
+ "f1_at_50": 0.0625,
326
+ "n_positive": 58,
327
+ "n_total": 1753,
328
+ "pos_rate": 0.03308613804905876,
329
+ "f1": 0.33082706275086216,
330
+ "optimal_threshold": 0.21164827048778534,
331
+ "threshold": 0.21164827048778534,
332
+ "recall_at_prec_30": 0.3103448275862069,
333
+ "recall_at_prec_50": 0.1896551724137931,
334
+ "recall_at_prec_70": 0.0,
335
+ "cutoff": 5.0,
336
+ "n_events": 1753
337
+ },
338
+ {
339
+ "auc_pr": 0.1838323482889146,
340
+ "auc_roc": 0.8097419204836084,
341
+ "f1_at_50": 0.06666666666666667,
342
+ "n_positive": 55,
343
+ "n_total": 1619,
344
+ "pos_rate": 0.033971587399629403,
345
+ "f1": 0.2741935434508325,
346
+ "optimal_threshold": 0.21228547394275665,
347
+ "threshold": 0.21228547394275665,
348
+ "recall_at_prec_30": 0.18181818181818182,
349
+ "recall_at_prec_50": 0.07272727272727272,
350
+ "recall_at_prec_70": 0.0,
351
+ "cutoff": 6.0,
352
+ "n_events": 1619
353
+ }
354
+ ]
355
+ }
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/augment.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Data augmentation for the conjunction prediction dataset.
3
+
4
+ The fundamental problem: only 67 high-risk events out of 13,154 in training (0.5%).
5
+ This module provides two augmentation strategies:
6
+
7
+ 1. SPACE-TRACK INTEGRATION: Merge real high-risk CDMs from Space-Track's cdm_public
8
+ feed. These have fewer features (16 vs 103) but provide real positive examples.
9
+
10
+ 2. TIME-SERIES AUGMENTATION: Create synthetic variants of existing high-risk events
11
+ by applying realistic perturbations:
12
+ - Gaussian noise on covariance/position/velocity features
13
+ - Temporal jittering (shift CDM creation times slightly)
14
+ - Feature dropout (randomly zero out some features, simulating missing data)
15
+ - Sequence truncation (remove early CDMs, simulating late detection)
16
+
17
+ Both strategies are physics-aware: they don't generate impossible configurations
18
+ (e.g., negative miss distances or covariance values).
19
+ """
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ from pathlib import Path
24
+
25
+
26
+ def augment_event_noise(
27
+ event_df: pd.DataFrame,
28
+ noise_scale: float = 0.05,
29
+ n_augments: int = 5,
30
+ rng: np.random.Generator = None,
31
+ ) -> list[pd.DataFrame]:
32
+ """
33
+ Create n_augments noisy variants of a single conjunction event.
34
+
35
+ Applies Gaussian noise to numeric features, scaled by each column's
36
+ standard deviation within the event. Preserves event_id structure and
37
+ ensures physical constraints (non-negative distances, etc.).
38
+ """
39
+ if rng is None:
40
+ rng = np.random.default_rng(42)
41
+
42
+ # Identify numeric columns to perturb (exclude IDs and targets)
43
+ exclude = {"event_id", "time_to_tca", "risk", "mission_id", "source"}
44
+ numeric_cols = event_df.select_dtypes(include=[np.number]).columns
45
+ perturb_cols = [c for c in numeric_cols if c not in exclude]
46
+
47
+ augmented = []
48
+ for i in range(n_augments):
49
+ aug = event_df.copy()
50
+
51
+ for col in perturb_cols:
52
+ values = aug[col].values.astype(float)
53
+ col_std = np.std(values)
54
+ if col_std < 1e-10:
55
+ col_std = np.abs(np.mean(values)) * 0.01 + 1e-10
56
+
57
+ noise = rng.normal(0, noise_scale * col_std, size=len(values))
58
+ aug[col] = values + noise
59
+
60
+ # Physical constraints
61
+ if "miss_distance" in aug.columns:
62
+ aug["miss_distance"] = aug["miss_distance"].clip(lower=0)
63
+ if "relative_speed" in aug.columns:
64
+ aug["relative_speed"] = aug["relative_speed"].clip(lower=0)
65
+
66
+ # Ensure covariance sigma columns stay positive
67
+ sigma_cols = [c for c in perturb_cols if "sigma" in c.lower()]
68
+ for col in sigma_cols:
69
+ aug[col] = aug[col].clip(lower=0)
70
+
71
+ augmented.append(aug)
72
+
73
+ return augmented
74
+
75
+
76
+ def augment_event_truncate(
77
+ event_df: pd.DataFrame,
78
+ min_keep: int = 3,
79
+ n_augments: int = 3,
80
+ rng: np.random.Generator = None,
81
+ ) -> list[pd.DataFrame]:
82
+ """
83
+ Create truncated variants by removing early CDMs.
84
+
85
+ Simulates late-detection scenarios where only the most recent CDMs
86
+ are available (closer to TCA).
87
+ """
88
+ if rng is None:
89
+ rng = np.random.default_rng(42)
90
+
91
+ # Sort by time_to_tca descending (first CDM = furthest from TCA)
92
+ event_df = event_df.sort_values("time_to_tca", ascending=False)
93
+ n_cdms = len(event_df)
94
+
95
+ if n_cdms <= min_keep:
96
+ return []
97
+
98
+ augmented = []
99
+ for _ in range(n_augments):
100
+ # Keep between min_keep and n_cdms-1 CDMs (always keep the last few)
101
+ n_keep = rng.integers(min_keep, n_cdms)
102
+ aug = event_df.iloc[-n_keep:].copy()
103
+ augmented.append(aug)
104
+
105
+ return augmented
106
+
107
+
108
def augment_positive_events(
    df: pd.DataFrame,
    target_ratio: float = 0.05,
    noise_scale: float = 0.05,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Augment the positive (high-risk) class to reach target_ratio.

    Args:
        df: full training DataFrame with event_id, risk columns
        target_ratio: desired fraction of high-risk events (default 5%)
        noise_scale: std dev of Gaussian noise as fraction of feature std
        seed: random seed

    Returns:
        Augmented DataFrame with new synthetic positive events appended.
        The input is returned unchanged when it is already at the target
        ratio, or when there are no positive events to clone.
    """
    rng = np.random.default_rng(seed)

    # An event is positive when its last reported risk exceeds -5
    # (risk appears to be log-scale per the dataset -- TODO confirm
    # against the loader's labeling convention).
    event_risks = df.groupby("event_id")["risk"].last()
    pos_event_ids = event_risks[event_risks > -5].index.tolist()
    neg_event_ids = event_risks[event_risks <= -5].index.tolist()

    n_pos = len(pos_event_ids)
    n_neg = len(neg_event_ids)
    n_total = n_pos + n_neg

    # How many positive events are needed so positives make up
    # target_ratio of the FINAL (augmented) event count?
    target_pos = int(target_ratio * (n_total / (1 - target_ratio)))
    n_needed = max(0, target_pos - n_pos)

    if n_needed == 0:
        print(f"Already at target ratio ({n_pos}/{n_total} = {n_pos/n_total:.1%})")
        return df

    # Guard: with zero positive events there is nothing to clone; the
    # generation loop below would raise (rng.choice on an empty list)
    # and never terminate.
    if n_pos == 0:
        print("No positive events available to augment; returning input unchanged")
        return df

    print(f"Augmenting: {n_pos} positive events → {n_pos + n_needed} "
          f"(target {target_ratio:.0%} of {n_total + n_needed})")

    # Generate augmented events, each under a fresh unique event_id.
    max_event_id = df["event_id"].max()
    augmented_dfs = []
    generated = 0

    while generated < n_needed:
        # Pick a random positive event to augment
        src_event_id = rng.choice(pos_event_ids)
        src_event = df[df["event_id"] == src_event_id]

        # Apply noise augmentation
        aug_variants = augment_event_noise(
            src_event, noise_scale=noise_scale, n_augments=1, rng=rng
        )

        # Also try truncation ~30% of the time for sequence-length variety.
        if rng.random() < 0.3 and len(src_event) > 3:
            trunc_variants = augment_event_truncate(
                src_event, n_augments=1, rng=rng
            )
            aug_variants.extend(trunc_variants)

        for aug_df in aug_variants:
            if generated >= n_needed:
                break
            max_event_id += 1
            aug_df = aug_df.copy()
            aug_df["event_id"] = max_event_id
            aug_df["source"] = "augmented"
            augmented_dfs.append(aug_df)
            generated += 1

    if augmented_dfs:
        augmented = pd.concat(augmented_dfs, ignore_index=True)
        result = pd.concat([df, augmented], ignore_index=True)

        # Verify and report the achieved class balance.
        event_risks = result.groupby("event_id")["risk"].last()
        new_pos = (event_risks > -5).sum()
        new_total = len(event_risks)
        print(f"Result: {new_pos} positive / {new_total} total "
              f"({new_pos/new_total:.1%})")
        return result

    return df
193
+
194
+
195
def integrate_spacetrack_positives(
    kelvins_df: pd.DataFrame,
    spacetrack_path: Path,
) -> pd.DataFrame:
    """
    Merge Space-Track emergency CDMs into the Kelvins training frame as
    extra positive examples.

    The public Space-Track feed carries only 16 of Kelvins' 103 features;
    the downstream merge fills the missing ones with 0 and the model
    learns from whatever is available. When no file exists at
    ``spacetrack_path``, the input frame is returned untouched.
    """
    if not spacetrack_path.exists():
        print(f"No Space-Track data at {spacetrack_path}")
        return kelvins_df

    # Imported lazily so the existence check above can run even in
    # environments where the merge module is unavailable.
    from src.data.merge_sources import (
        load_spacetrack_cdms, group_into_events, merge_datasets
    )

    spacetrack_events = group_into_events(load_spacetrack_cdms(spacetrack_path))
    return merge_datasets(kelvins_df, spacetrack_events)
219
+
220
+
221
def build_augmented_training_set(
    data_dir: Path,
    target_positive_ratio: float = 0.05,
    noise_scale: float = 0.05,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Build the full augmented training set from all available sources.

    Steps:
    1. Load ESA Kelvins train/test
    2. Merge Space-Track emergency CDMs into training set
    3. Apply time-series augmentation to positive events
    4. Return (augmented_train, original_test)

    Test set is NEVER augmented — it stays as Kelvins-only for fair evaluation.

    Args:
        data_dir: root data directory; expects a "cdm" subfolder and,
            optionally, "cdm_spacetrack/cdm_spacetrack_emergency.csv".
        target_positive_ratio: desired high-risk event fraction after
            augmentation (forwarded to augment_positive_events).
        noise_scale: Gaussian noise scale for synthetic variants.
        seed: random seed for reproducible augmentation.

    Returns:
        (augmented_train_df, untouched_test_df) pair of DataFrames.
    """
    # Function-scope import -- presumably to avoid an import cycle with
    # the loader module; confirm before moving it to the top of the file.
    from src.data.cdm_loader import load_dataset

    print("=" * 60)
    print(" Building Augmented Training Set")
    print("=" * 60)

    # Step 1: Load Kelvins
    print("\n1. Loading ESA Kelvins dataset ...")
    train_df, test_df = load_dataset(data_dir / "cdm")

    # Defragment and tag source
    train_df = train_df.copy()
    test_df = test_df.copy()
    train_df["source"] = "kelvins"
    test_df["source"] = "kelvins"

    # Count initial positives. Risk > -5 marks a high-risk event
    # (same threshold as augment_positive_events).
    event_risks = train_df.groupby("event_id")["risk"].last()
    n_pos_initial = (event_risks > -5).sum()
    n_total_initial = len(event_risks)
    print(f" Initial: {n_pos_initial} positive / {n_total_initial} total "
          f"({n_pos_initial/n_total_initial:.2%})")

    # Step 2: Space-Track integration (optional -- skipped when the
    # emergency-CDM CSV is absent).
    st_path = data_dir / "cdm_spacetrack" / "cdm_spacetrack_emergency.csv"
    if st_path.exists():
        print(f"\n2. Integrating Space-Track emergency CDMs ...")
        train_df = integrate_spacetrack_positives(train_df, st_path)
    else:
        print(f"\n2. No Space-Track data found (skipping)")

    # Step 3: Time-series augmentation of the positive class.
    print(f"\n3. Augmenting positive events (target ratio: {target_positive_ratio:.0%}) ...")
    train_df = augment_positive_events(
        train_df,
        target_ratio=target_positive_ratio,
        noise_scale=noise_scale,
        seed=seed,
    )

    # Final stats, broken down per source ("kelvins" / "spacetrack" /
    # "augmented") using each event's first-row source tag.
    event_risks = train_df.groupby("event_id")["risk"].last()
    event_sources = train_df.groupby("event_id")["source"].first()
    n_kelvins = (event_sources == "kelvins").sum()
    n_spacetrack = (event_sources == "spacetrack").sum()
    n_augmented = (event_sources == "augmented").sum()
    n_pos_final = (event_risks > -5).sum()
    n_total_final = len(event_risks)

    print(f"\n{'=' * 60}")
    print(f" Final Training Set:")
    print(f" Kelvins events: {n_kelvins}")
    print(f" Space-Track events: {n_spacetrack}")
    print(f" Augmented events: {n_augmented}")
    print(f" Total events: {n_total_final}")
    print(f" Positive events: {n_pos_final} ({n_pos_final/n_total_final:.1%})")
    print(f" Total CDM rows: {len(train_df)}")
    print(f"{'=' * 60}")

    return train_df, test_df
src/data/cdm_loader.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Load and parse ESA Kelvins CDM dataset into structured formats."""
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Optional
9
+
10
+
11
@dataclass
class CDMSnapshot:
    """A single Conjunction Data Message update.

    One numeric row of the Kelvins CDM dataset for a given event, captured
    at ``time_to_tca`` before the predicted time of closest approach.
    """
    time_to_tca: float      # time remaining until closest approach (dataset units, larger = earlier)
    miss_distance: float    # predicted miss distance at TCA
    relative_speed: float   # predicted relative speed at TCA
    risk: float             # Kelvins "risk" target; thresholded at -5 elsewhere in this module
    features: np.ndarray  # all numeric columns as a flat vector
19
+
20
+
21
@dataclass
class ConjunctionEvent:
    """A complete conjunction event = sequence of CDM snapshots."""
    event_id: int
    cdm_sequence: List[CDMSnapshot] = field(default_factory=list)
    # NOTE(review): despite the original "any CDM" wording, build_events()
    # derives this from the FINAL CDM only (1 iff last snapshot's risk > -5).
    risk_label: int = 0
    final_miss_distance: float = 0.0  # miss_distance of the last CDM in the sequence
    altitude_km: float = 0.0          # apogee altitude from t_h_apo/c_h_apo when available, else 0
    object_type: str = ""             # chaser type from c_object_type, or "unknown"
30
+
31
+
32
+ # Columns we use for the feature vector (numeric only, excluding IDs/targets)
33
+ EXCLUDE_COLS = {"event_id", "time_to_tca", "risk", "mission_id"}
34
+
35
+
36
def load_cdm_csv(path: Path) -> pd.DataFrame:
    """Load a CDM CSV and do basic cleaning."""
    frame = pd.read_csv(path)

    # Numeric columns minus IDs/targets form the feature set; some sparse
    # covariance columns contain NaN, which we zero-fill in place.
    numeric = frame.select_dtypes(include=[np.number]).columns.tolist()
    feature_names = [name for name in numeric if name not in EXCLUDE_COLS]
    frame[feature_names] = frame[feature_names].fillna(0)

    return frame
48
+
49
+
50
def load_dataset(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load train and test CDM DataFrames.

    Searches data_dir recursively because the CSVs may sit in a
    subdirectory after archive extraction.
    """
    train_matches = list(data_dir.rglob("*train*.csv"))
    test_matches = list(data_dir.rglob("*test*.csv"))

    if not train_matches:
        raise FileNotFoundError(f"No train CSV found in {data_dir}")
    if not test_matches:
        raise FileNotFoundError(f"No test CSV found in {data_dir}")

    # First match wins, mirroring the original selection behavior.
    train_path = train_matches[0]
    test_path = test_matches[0]

    print(f"Loading train: {train_path}")
    print(f"Loading test: {test_path}")

    train_df = load_cdm_csv(train_path)
    test_df = load_cdm_csv(test_path)

    print(f"Train: {len(train_df)} rows, {train_df['event_id'].nunique()} events")
    print(f"Test: {len(test_df)} rows, {test_df['event_id'].nunique()} events")

    return train_df, test_df
74
+
75
+
76
def get_feature_columns(df: pd.DataFrame) -> list[str]:
    """Get the list of numeric feature columns (excluding IDs and targets)."""
    candidates = df.select_dtypes(include=[np.number]).columns.tolist()
    return [name for name in candidates if name not in EXCLUDE_COLS]
80
+
81
+
82
def build_events(df: pd.DataFrame, feature_cols: Optional[list[str]] = None) -> list[ConjunctionEvent]:
    """Group CDM rows by event_id into ConjunctionEvent objects.

    Args:
        df: CDM DataFrame with one row per CDM update.
        feature_cols: optional fixed list of feature columns (pass the
            training-set columns when building test events so feature
            vectors stay aligned across splits).

    Returns:
        One ConjunctionEvent per event_id; each CDM sequence is ordered from
        earliest (largest time_to_tca) to latest (smallest time_to_tca).
    """
    if feature_cols is None:
        feature_cols = get_feature_columns(df)
    else:
        # Ensure all requested columns exist; fill missing with 0.
        # (Copy df at most once, not once per missing column.)
        missing = [c for c in feature_cols if c not in df.columns]
        if missing:
            df = df.copy()
            for col in missing:
                df[col] = 0.0

    # Pre-extract per-row arrays in the ORIGINAL row order so a stored row
    # index can address them after the dataframe is re-sorted below.
    feature_matrix = df[feature_cols].values  # (N, F) float64
    feature_matrix = np.nan_to_num(feature_matrix, nan=0.0, posinf=0.0, neginf=0.0)

    n_rows = len(df)
    tca_arr = df["time_to_tca"].to_numpy(dtype=float) if "time_to_tca" in df.columns else np.zeros(n_rows)
    miss_arr = df["miss_distance"].to_numpy(dtype=float) if "miss_distance" in df.columns else np.zeros(n_rows)
    speed_arr = df["relative_speed"].to_numpy(dtype=float) if "relative_speed" in df.columns else np.zeros(n_rows)
    risk_arr = df["risk"].to_numpy(dtype=float) if "risk" in df.columns else np.zeros(n_rows)

    # Sort by event_id then time_to_tca descending (earliest CDM first).
    df = df.copy()
    df["_row_idx"] = np.arange(n_rows)
    df = df.sort_values(["event_id", "time_to_tca"], ascending=[True, False])

    # Determine altitude column (target apogee preferred, then chaser apogee).
    alt_col = next((c for c in ("t_h_apo", "c_h_apo") if c in df.columns), None)
    has_obj_type = "c_object_type" in df.columns

    events = []
    for event_id, group in df.groupby("event_id", sort=True):
        row_indices = group["_row_idx"].values

        # BUG FIX: the original called df.iloc[ridx] AFTER sorting, using row
        # positions captured BEFORE the sort — so scalar fields were read
        # from the wrong rows. Index the pre-extracted (original-order)
        # arrays instead; this is also far faster than per-row .iloc.
        cdm_seq = [
            CDMSnapshot(
                time_to_tca=float(tca_arr[ridx]),
                miss_distance=float(miss_arr[ridx]),
                relative_speed=float(speed_arr[ridx]),
                risk=float(risk_arr[ridx]),
                features=feature_matrix[ridx].astype(np.float32),
            )
            for ridx in row_indices
        ]

        final_cdm = cdm_seq[-1]
        # Kelvins convention used throughout this module: risk > -5 is "high risk".
        risk_label = 1 if final_cdm.risk > -5 else 0
        alt = float(group[alt_col].iloc[-1]) if alt_col else 0.0
        obj_type = str(group["c_object_type"].iloc[0]) if has_obj_type else "unknown"

        events.append(ConjunctionEvent(
            event_id=int(event_id),
            cdm_sequence=cdm_seq,
            risk_label=risk_label,
            final_miss_distance=final_cdm.miss_distance,
            altitude_km=alt,
            object_type=obj_type,
        ))

    n_high = sum(e.risk_label for e in events)
    if events:  # guard the percentage against an empty DataFrame
        print(f"Built {len(events)} events, {n_high} high-risk ({100*n_high/len(events):.1f}%)")
    return events
152
+
153
+
154
+ def events_to_flat_features(events: list[ConjunctionEvent]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
155
+ """
156
+ Extract flat feature vectors from events for classical ML.
157
+ Uses the LAST CDM snapshot (closest to TCA) + temporal trend features.
158
+
159
+ Returns: (X, y_risk, y_miss)
160
+ """
161
+ X_list = []
162
+ y_risk = []
163
+ y_miss = []
164
+
165
+ for event in events:
166
+ seq = event.cdm_sequence
167
+ last = seq[-1]
168
+ base = last.features.copy()
169
+
170
+ miss_values = np.array([s.miss_distance for s in seq])
171
+ risk_values = np.array([s.risk for s in seq])
172
+ tca_values = np.array([s.time_to_tca for s in seq])
173
+
174
+ n_cdms = len(seq)
175
+ miss_mean = float(np.mean(miss_values)) if n_cdms > 0 else 0.0
176
+ miss_std = float(np.std(miss_values)) if n_cdms > 1 else 0.0
177
+
178
+ miss_trend = 0.0
179
+ if n_cdms > 1 and np.std(tca_values) > 0:
180
+ miss_trend = float(np.polyfit(tca_values, miss_values, 1)[0])
181
+
182
+ risk_trend = 0.0
183
+ if n_cdms > 1 and np.std(tca_values) > 0:
184
+ risk_trend = float(np.polyfit(tca_values, risk_values, 1)[0])
185
+
186
+ temporal_feats = np.array([
187
+ n_cdms,
188
+ miss_mean,
189
+ miss_std,
190
+ miss_trend,
191
+ risk_trend,
192
+ float(miss_values[0] - miss_values[-1]) if n_cdms > 1 else 0.0,
193
+ last.time_to_tca,
194
+ last.relative_speed,
195
+ ], dtype=np.float32)
196
+
197
+ combined = np.concatenate([base, temporal_feats])
198
+ X_list.append(combined)
199
+ y_risk.append(event.risk_label)
200
+ y_miss.append(np.log1p(max(event.final_miss_distance, 0.0)))
201
+
202
+ X = np.stack(X_list)
203
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
204
+
205
+ return X, np.array(y_risk), np.array(y_miss)
src/data/counterfactual.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SGP4 counterfactual propagation — "what if no maneuver?" simulation.
2
+
3
+ For each likely-avoidance maneuver, propagates the pre-maneuver TLE forward
4
+ to estimate whether a close approach would have occurred. This generates
5
+ counterfactual "would-have-collided" labels for training enrichment.
6
+
7
+ Uses the sgp4 library for efficient satellite propagation.
8
+ """
9
+
10
+ import math
11
+ import numpy as np
12
+ from datetime import datetime, timedelta, timezone
13
+
14
+ try:
15
+ from sgp4.api import Satrec, WGS72
16
+ from sgp4 import exporter
17
+ SGP4_AVAILABLE = True
18
+ except ImportError:
19
+ SGP4_AVAILABLE = False
20
+
21
+ # Earth parameters
22
+ EARTH_RADIUS_KM = 6378.137
23
+
24
+ # Counterfactual thresholds
25
+ COLLISION_THRESHOLD_KM = 1.0 # "Would have collided" if closer than this
26
+ NEARBY_ALT_BAND_KM = 50.0 # Altitude proximity for neighbor selection
27
+ NEARBY_RAAN_BAND_DEG = 30.0 # RAAN proximity for neighbor selection
28
+
29
+
30
def celestrak_json_to_satrec(tle_json: dict) -> "Satrec":
    """Convert a CelesTrak GP JSON record to an sgp4 Satrec object.

    Prefers TLE_LINE1/TLE_LINE2 when present (most reliable); otherwise
    constructs the satellite from orbital elements via sgp4init().

    Args:
        tle_json: CelesTrak GP JSON dict with orbital elements.

    Returns:
        sgp4 Satrec object ready for propagation.

    Raises:
        ImportError: If sgp4 is not installed.
        ValueError: If TLE data is insufficient.
    """
    if not SGP4_AVAILABLE:
        raise ImportError("sgp4 library is required: pip install sgp4")

    # Prefer TLE lines if available (most reliable)
    line1 = tle_json.get("TLE_LINE1", "")
    line2 = tle_json.get("TLE_LINE2", "")
    if line1 and line2:
        return Satrec.twoline2rv(line1, line2)

    # Construct from JSON orbital elements using sgp4init
    satrec = Satrec()

    # Parse epoch (CelesTrak uses ISO-8601, occasionally with a trailing Z)
    epoch_str = tle_json.get("EPOCH", "")
    if not epoch_str:
        raise ValueError("No EPOCH in TLE JSON")

    epoch_dt = datetime.fromisoformat(epoch_str.replace("Z", "+00:00"))
    if epoch_dt.tzinfo is None:
        epoch_dt = epoch_dt.replace(tzinfo=timezone.utc)

    # sgp4init expects mean motion in rad/min and all angles in radians.
    no_kozai = float(tle_json.get("MEAN_MOTION", 0)) * (2.0 * math.pi / 1440.0)  # rev/day -> rad/min
    ecco = float(tle_json.get("ECCENTRICITY", 0))
    inclo = math.radians(float(tle_json.get("INCLINATION", 0)))
    nodeo = math.radians(float(tle_json.get("RA_OF_ASC_NODE", 0)))
    argpo = math.radians(float(tle_json.get("ARG_OF_PERICENTER", 0)))
    mo = math.radians(float(tle_json.get("MEAN_ANOMALY", 0)))
    bstar = float(tle_json.get("BSTAR", 0))
    norad_id = int(tle_json.get("NORAD_CAT_ID", 0))

    # Epoch for sgp4init is days since 1949 Dec 31 00:00 UT (JD 2433281.5).
    # (The original also computed calendar fields and a whole/fraction JD
    # split here that were never used; they have been removed.)
    epoch_jd = _datetime_to_jd(epoch_dt)

    satrec.sgp4init(
        WGS72,                   # gravity model
        'i',                     # 'a' = old AFSPC mode, 'i' = improved
        norad_id,                # NORAD catalog number
        (epoch_jd - 2433281.5),  # epoch in days since 1949 Dec 31 00:00 UT
        bstar,                   # BSTAR drag term
        0.0,                     # ndot (not used in sgp4init 'i' mode)
        0.0,                     # nddot (not used)
        ecco,                    # eccentricity
        argpo,                   # argument of perigee (radians)
        inclo,                   # inclination (radians)
        mo,                      # mean anomaly (radians)
        no_kozai,                # mean motion (radians/minute)
        nodeo,                   # RAAN (radians)
    )

    return satrec
110
+
111
+
112
+ def _datetime_to_jd(dt: datetime) -> float:
113
+ """Convert datetime to Julian Date."""
114
+ if dt.tzinfo is not None:
115
+ dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
116
+ a = (14 - dt.month) // 12
117
+ y = dt.year + 4800 - a
118
+ m = dt.month + 12 * a - 3
119
+ jdn = dt.day + (153 * m + 2) // 5 + 365 * y + y // 4 - y // 100 + y // 400 - 32045
120
+ jd = jdn + (dt.hour - 12) / 24.0 + dt.minute / 1440.0 + dt.second / 86400.0
121
+ return jd
122
+
123
+
124
def _propagate_positions(satrec: "Satrec", start_jd: float, hours: float, step_min: float) -> np.ndarray:
    """Propagate a satellite and return its positions as an (N, 3) km array.

    Steps where SGP4 reports an error are dropped; an empty (0, 3) array is
    returned if no step propagates successfully.
    """
    total_steps = int(hours * 60 / step_min) + 1
    # Offset from the TLE epoch to the requested start time, in minutes.
    epoch_offset_min = (start_jd - satrec.jdsatepoch - satrec.jdsatepochF) * 1440.0

    collected = []
    for step in range(total_steps):
        offset_min = epoch_offset_min + step * step_min
        err, pos, _vel = satrec.sgp4(satrec.jdsatepoch, satrec.jdsatepochF + offset_min / 1440.0)
        if err == 0:
            collected.append(pos)

    if collected:
        return np.array(collected)
    return np.array([]).reshape(0, 3)
142
+
143
+
144
def find_nearby_satellites(
    maneuvered_tle: dict,
    all_tles: list[dict],
    alt_band_km: float = NEARBY_ALT_BAND_KM,
    raan_band_deg: float = NEARBY_RAAN_BAND_DEG,
) -> list[dict]:
    """Find satellites in similar orbital shell to the maneuvered object."""
    from src.data.maneuver_detector import mean_motion_to_sma, sma_to_altitude

    own_id = int(maneuvered_tle.get("NORAD_CAT_ID", 0))
    own_alt = sma_to_altitude(mean_motion_to_sma(float(maneuvered_tle.get("MEAN_MOTION", 0))))
    own_raan = float(maneuvered_tle.get("RA_OF_ASC_NODE", 0))

    matches = []
    for candidate in all_tles:
        cand_id = int(candidate.get("NORAD_CAT_ID", 0))
        # Skip the maneuvered object itself and records without a valid ID.
        if cand_id == own_id or cand_id <= 0:
            continue

        cand_alt = sma_to_altitude(mean_motion_to_sma(float(candidate.get("MEAN_MOTION", 0))))
        cand_raan = float(candidate.get("RA_OF_ASC_NODE", 0))

        # Wrap the RAAN difference into [0, 180] degrees before comparing.
        raan_delta = abs(cand_raan - own_raan)
        raan_delta = min(raan_delta, 360.0 - raan_delta)

        if abs(cand_alt - own_alt) < alt_band_km and raan_delta < raan_band_deg:
            matches.append(candidate)

    return matches
176
+
177
+
178
def propagate_counterfactual(
    pre_maneuver_tle: dict,
    nearby_tles: list[dict],
    hours_forward: float = 24.0,
    step_minutes: float = 10.0,
) -> dict:
    """Simulate "what if no maneuver?" using SGP4 propagation.

    Propagates the pre-maneuver TLE (before orbit change) forward and
    checks for close approaches with nearby satellites.

    Args:
        pre_maneuver_tle: Yesterday's TLE for the maneuvered satellite.
        nearby_tles: Current TLEs for nearby satellites.
        hours_forward: How far to propagate (hours).
        step_minutes: Time step for propagation (minutes).

    Returns:
        Dict with: min_distance_km, time_of_closest_approach,
        would_have_collided, closest_norad_id, n_neighbors_checked.
    """
    if not SGP4_AVAILABLE:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": "sgp4 not installed",
        }

    try:
        target_sat = celestrak_json_to_satrec(pre_maneuver_tle)
    # Fix: `except (ValueError, Exception)` was redundant — Exception
    # already covers ValueError.
    except Exception as e:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": f"target TLE parse failed: {e}",
        }

    # Use current time as propagation start
    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    # Propagate maneuvered satellite (pre-maneuver orbit)
    target_positions = _propagate_positions(target_sat, start_jd, hours_forward, step_minutes)
    if len(target_positions) == 0:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": "target propagation failed",
        }

    global_min_dist = float("inf")
    closest_norad = 0
    closest_time_offset_min = 0.0
    n_checked = 0

    for neighbor_tle in nearby_tles:
        try:
            neighbor_sat = celestrak_json_to_satrec(neighbor_tle)
        except Exception:
            continue  # skip unparseable neighbors; best-effort scan

        neighbor_positions = _propagate_positions(neighbor_sat, start_jd, hours_forward, step_minutes)
        if len(neighbor_positions) == 0:
            continue

        n_checked += 1

        # Compute distances at each timestep (use min of overlapping steps)
        n_common = min(len(target_positions), len(neighbor_positions))
        diffs = target_positions[:n_common] - neighbor_positions[:n_common]
        distances = np.linalg.norm(diffs, axis=1)
        min_idx = int(np.argmin(distances))
        min_dist = float(distances[min_idx])

        if min_dist < global_min_dist:
            global_min_dist = min_dist
            closest_norad = int(neighbor_tle.get("NORAD_CAT_ID", 0))
            closest_time_offset_min = min_idx * step_minutes

    if global_min_dist == float("inf"):
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "n_neighbors_checked": n_checked,
            "error": "no valid neighbors propagated",
        }

    tca_dt = now + timedelta(minutes=closest_time_offset_min)

    return {
        # float() keeps the value a native Python float (JSON-safe).
        "min_distance_km": round(float(global_min_dist), 3),
        "time_of_closest_approach": tca_dt.isoformat(),
        "would_have_collided": global_min_dist < COLLISION_THRESHOLD_KM,
        "closest_norad_id": closest_norad,
        "n_neighbors_checked": n_checked,
    }
274
+
275
+
276
def compute_forward_trajectory(
    tle_1: dict,
    tle_2: dict,
    hours_forward: float = 120.0,
    step_minutes: float = 20.0,
) -> list[dict] | None:
    """Compute full trajectory time series for two satellites.

    Returns list of trajectory points with ECI positions and separation
    distance, suitable for baking into the webapp alerts JSON so the
    frontend doesn't need to do SGP4 propagation or load TLE data.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        hours_forward: How far to propagate (default 120h = 5 days).
        step_minutes: Time step for propagation (minutes).

    Returns:
        List of dicts with: h (hours from start), d (distance km),
        s1 [x,y,z] ECI km, s2 [x,y,z] ECI km. None if propagation fails.
    """
    if not SGP4_AVAILABLE:
        return None

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Fix: `except (ValueError, Exception)` was redundant — Exception
    # already covers ValueError.
    except Exception:
        return None

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    n_steps = int(hours_forward * 60 / step_minutes) + 1
    points = []

    for i in range(n_steps):
        mins = i * step_minutes
        target_jd = start_jd + mins / 1440.0
        # sgp4() takes a (whole, fraction) Julian date pair.
        jd_whole = int(target_jd)
        jd_frac = target_jd - jd_whole

        e1, r1, _ = sat1.sgp4(jd_whole, jd_frac)
        e2, r2, _ = sat2.sgp4(jd_whole, jd_frac)

        # Skip steps where either propagation errored or returned non-finite values.
        if e1 != 0 or e2 != 0:
            continue
        if not all(math.isfinite(v) for v in r1 + r2):
            continue

        dx = r1[0] - r2[0]
        dy = r1[1] - r2[1]
        dz = r1[2] - r2[2]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)

        points.append({
            "h": round(mins / 60.0, 2),
            "d": round(dist, 1),
            "s1": [round(r1[0], 1), round(r1[1], 1), round(r1[2], 1)],
            "s2": [round(r2[0], 1), round(r2[1], 1), round(r2[2], 1)],
        })

    return points if points else None
340
+
341
+
342
def compute_tca_trail(
    tle_1: dict,
    tle_2: dict,
    tca_hours: float,
    half_window_min: float = 30.0,
    step_minutes: float = 0.25,
) -> list[dict] | None:
    """Compute dense trail around TCA for globe orbital path visualization.

    Returns 15-sec resolution positions for ±30 min around TCA.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        tca_hours: Hours from now to TCA (from compute_forward_tca).
        half_window_min: Half window in minutes around TCA.
        step_minutes: Time step in minutes.

    Returns:
        List of dicts with s1 [x,y,z] and s2 [x,y,z] ECI km. None if fails.
    """
    if not SGP4_AVAILABLE:
        return None

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Fix: `except (ValueError, Exception)` was redundant — Exception
    # already covers ValueError.
    except Exception:
        return None

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    # Window is [TCA - half_window, TCA + half_window], in minutes from now.
    tca_min = tca_hours * 60.0
    t_start = tca_min - half_window_min
    t_end = tca_min + half_window_min
    n_steps = int((t_end - t_start) / step_minutes) + 1

    trail = []
    for i in range(n_steps):
        mins = t_start + i * step_minutes
        target_jd = start_jd + mins / 1440.0
        # sgp4() takes a (whole, fraction) Julian date pair.
        jd_whole = int(target_jd)
        jd_frac = target_jd - jd_whole

        e1, r1, _ = sat1.sgp4(jd_whole, jd_frac)
        e2, r2, _ = sat2.sgp4(jd_whole, jd_frac)

        # Skip steps where either propagation errored or returned non-finite values.
        if e1 != 0 or e2 != 0:
            continue
        if not all(math.isfinite(v) for v in r1 + r2):
            continue

        dx = r1[0] - r2[0]
        dy = r1[1] - r2[1]
        dz = r1[2] - r2[2]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)

        trail.append({
            "h": round(mins / 60.0, 3),
            "d": round(dist, 1),
            "s1": [round(r1[0], 1), round(r1[1], 1), round(r1[2], 1)],
            "s2": [round(r2[0], 1), round(r2[1], 1), round(r2[2], 1)],
        })

    return trail if trail else None
408
+
409
+
410
def compute_forward_tca(
    tle_1: dict,
    tle_2: dict,
    hours_forward: float = 120.0,
    step_minutes: float = 10.0,
) -> dict:
    """Compute forward Time of Closest Approach between two satellites.

    Propagates both satellites forward using SGP4 and finds the minimum
    separation distance and when it occurs.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        hours_forward: How far to propagate (default 120h = 5 days).
        step_minutes: Time step for propagation (minutes).

    Returns:
        Dict with: tca_hours, tca_min_distance_km (both None on failure).
    """
    if not SGP4_AVAILABLE:
        return {"tca_hours": None, "tca_min_distance_km": None}

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Fix: `except (ValueError, Exception) as e` was redundant (Exception
    # covers ValueError) and the bound exception was never used.
    except Exception:
        return {"tca_hours": None, "tca_min_distance_km": None}

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    pos1 = _propagate_positions(sat1, start_jd, hours_forward, step_minutes)
    pos2 = _propagate_positions(sat2, start_jd, hours_forward, step_minutes)

    if len(pos1) == 0 or len(pos2) == 0:
        return {"tca_hours": None, "tca_min_distance_km": None}

    # Compare only the overlapping portion of both trajectories.
    n_common = min(len(pos1), len(pos2))
    diffs = pos1[:n_common] - pos2[:n_common]
    distances = np.linalg.norm(diffs, axis=1)
    min_idx = int(np.argmin(distances))
    min_dist = float(distances[min_idx])
    tca_hours = min_idx * step_minutes / 60.0

    return {
        "tca_hours": round(tca_hours, 1),
        "tca_min_distance_km": round(min_dist, 1),
    }
src/data/density_features.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code — 2026-02-13
2
+ """Orbital density features derived from the CRASH Clock framework.
3
+
4
+ Computes population-level orbital density metrics for each conjunction event,
5
+ based on the altitude distribution of all events in the training set.
6
+
7
+ The key insight from Thiele et al. (2025) "An Orbital House of Cards":
8
+ collision rate scales as n² * A_col * v_r — so a conjunction at a crowded
9
+ altitude (550 km Starlink shell) is fundamentally riskier than the same
10
+ miss_distance at a sparse altitude (1200 km).
11
+
12
+ These features are computed from the TRAINING set only and applied to
13
+ validation/test sets to prevent data leakage.
14
+ """
15
+
16
+ import json
17
+ import numpy as np
18
+ import pandas as pd
19
+ from pathlib import Path
20
+
21
+ # Physical constants
22
+ EARTH_RADIUS_KM = 6371.0
23
+ GM_M3_S2 = 3.986004418e14 # Earth gravitational parameter (m³/s²)
24
+
25
+ # CRASH Clock cross-sections from Thiele et al. Table (10m-5m-10cm)
26
+ A_COL_SAT_SAT = 300.0 # m² (satellite-satellite, 10m approach)
27
+ A_COL_SAT_DEBRIS = 79.0 # m² (satellite-debris, 5m approach)
28
+
29
+ # Altitude binning
30
+ BIN_WIDTH_KM = 25 # km per altitude bin
31
+ ALT_MIN_KM = 150
32
+ ALT_MAX_KM = 2100
33
+
34
+ # Feature names that will be added to DataFrames
35
+ DENSITY_FEATURES = [
36
+ "shell_density", # events per km³ in altitude bin
37
+ "shell_collision_rate", # Γ from CRASH Clock Eq. 2 (per second)
38
+ "local_crash_clock_log", # log10(seconds to expected collision in shell)
39
+ "altitude_percentile", # CDF position in event altitude distribution
40
+ "n_events_in_shell", # raw count of training events at this altitude
41
+ "shell_risk_rate", # fraction of high-risk events in this altitude bin
42
+ ]
43
+
44
+
45
def _orbital_speed_kms(altitude_km: float) -> float:
    """Circular orbital speed (km/s) at a given altitude above Earth."""
    orbit_radius_m = 1000.0 * (EARTH_RADIUS_KM + altitude_km)
    speed_m_s = np.sqrt(GM_M3_S2 / orbit_radius_m)
    return speed_m_s / 1000.0  # m/s -> km/s
49
+
50
+
51
def _mean_relative_speed_kms(altitude_km: float) -> float:
    """Average relative encounter speed: v_r = (4/3) * v_orbital (Eq. 7)."""
    return _orbital_speed_kms(altitude_km) * (4.0 / 3.0)
54
+
55
+
56
def _shell_volume_km3(altitude_km: float, width_km: float) -> float:
    """Volume of a spherical shell at given altitude with given width."""
    half_width = width_km / 2.0
    inner_radius = EARTH_RADIUS_KM + altitude_km - half_width
    outer_radius = EARTH_RADIUS_KM + altitude_km + half_width
    return (4.0 / 3.0) * np.pi * (outer_radius**3 - inner_radius**3)
61
+
62
+
63
class OrbitalDensityComputer:
    """Computes orbital density features from a training DataFrame.

    Fit on training data, then transform any DataFrame (train/val/test)
    to add density-based static features per event.

    The density is computed from event altitudes, NOT from a full TLE
    catalog, so it represents the conjunction density distribution rather
    than the full RSO population. For the Kelvins dataset, this captures
    where conjunction events cluster (which correlates with RSO density).
    """

    def __init__(self, bin_width_km: float = BIN_WIDTH_KM):
        self.bin_width_km = bin_width_km
        # Fixed altitude bins spanning [ALT_MIN_KM, ALT_MAX_KM].
        self.bin_edges = np.arange(ALT_MIN_KM, ALT_MAX_KM + bin_width_km, bin_width_km)
        self.bin_centers = (self.bin_edges[:-1] + self.bin_edges[1:]) / 2.0
        self.n_bins = len(self.bin_centers)

        # Fitted state (populated by fit())
        self.event_counts = None       # events per bin
        self.density_per_bin = None    # events / km³ per bin
        self.collision_rate = None     # Γ per bin (events/s)
        self.crash_clock_log = None    # log10(seconds to collision) per bin
        self.risk_rate_per_bin = None  # fraction high-risk per bin
        self.altitude_cdf = None       # cumulative distribution
        self.is_fitted = False

    def _event_altitude(self, df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
        """Compute conjunction altitude for each event (last CDM row).

        Uses mean of target and chaser perigee altitudes as the approximate
        conjunction altitude. Falls back to semi-major axis minus Earth radius.

        Returns:
            (altitudes, event_ids): clamped altitudes in km and the matching
            event_id values, aligned element-for-element.
            (NOTE: original annotation said ``-> np.ndarray``; the method
            actually returns this 2-tuple.)
        """
        event_df = df.groupby("event_id").last()

        # Primary: mean of perigee altitudes (where most conjunctions happen)
        t_alt = np.zeros(len(event_df))
        c_alt = np.zeros(len(event_df))

        if "t_h_per" in event_df.columns:
            t_alt = event_df["t_h_per"].fillna(0).values
        elif "t_j2k_sma" in event_df.columns:
            t_alt = event_df["t_j2k_sma"].fillna(EARTH_RADIUS_KM).values - EARTH_RADIUS_KM

        if "c_h_per" in event_df.columns:
            c_alt = event_df["c_h_per"].fillna(0).values
        elif "c_j2k_sma" in event_df.columns:
            c_alt = event_df["c_j2k_sma"].fillna(EARTH_RADIUS_KM).values - EARTH_RADIUS_KM

        altitudes = (t_alt + c_alt) / 2.0
        # Clamp to valid range so every event lands inside a histogram bin
        altitudes = np.clip(altitudes, ALT_MIN_KM, ALT_MAX_KM - 1)
        return altitudes, event_df.index.values

    def fit(self, train_df: pd.DataFrame) -> "OrbitalDensityComputer":
        """Fit density distribution from training data.

        Must be called before transform(). Only uses training data
        to prevent information leakage into validation/test sets.
        """
        altitudes, event_ids = self._event_altitude(train_df)

        # Histogram: count events per altitude bin
        self.event_counts, _ = np.histogram(altitudes, bins=self.bin_edges)

        # Density: events per km³ in each shell
        volumes = np.array([
            _shell_volume_km3(c, self.bin_width_km)
            for c in self.bin_centers
        ])
        # Guard against division by zero for degenerate shells
        self.density_per_bin = self.event_counts / np.maximum(volumes, 1e-6)

        # Collision rate per shell: Γ = (1/2) * n² * A_col * v_r * V
        # Using satellite-satellite cross-section as the primary concern
        self.collision_rate = np.zeros(self.n_bins)
        for i, (center, density, volume) in enumerate(
            zip(self.bin_centers, self.density_per_bin, volumes)
        ):
            v_r = _mean_relative_speed_kms(center)  # km/s
            # Convert A_col from m² to km², v_r already in km/s
            a_col_km2 = A_COL_SAT_SAT / 1e6  # m² → km²
            # Γ = 0.5 * n² * A * v_r * V (units: per second)
            gamma = 0.5 * density**2 * a_col_km2 * v_r * volume
            self.collision_rate[i] = gamma

        # CRASH Clock per shell: τ = 1/Γ (in seconds), log10 for feature
        with np.errstate(divide="ignore"):
            tau = 1.0 / np.maximum(self.collision_rate, 1e-30)
        # Clip so the log feature stays in a bounded, finite range
        self.crash_clock_log = np.log10(np.clip(tau, 1.0, 1e15))

        # Risk rate per bin: fraction of positive events
        # (risk > -5 marks a positive event — same threshold used repo-wide)
        risk_per_event = train_df.groupby("event_id")["risk"].last()
        is_high_risk = (risk_per_event > -5).astype(float).values

        self.risk_rate_per_bin = np.zeros(self.n_bins)
        for i in range(self.n_bins):
            mask = (altitudes >= self.bin_edges[i]) & (altitudes < self.bin_edges[i + 1])
            if mask.sum() > 0:
                self.risk_rate_per_bin[i] = is_high_risk[mask].mean()

        # Cumulative altitude distribution for percentile feature
        sorted_alts = np.sort(altitudes)
        self.altitude_cdf = sorted_alts

        self.is_fitted = True
        print(f" OrbitalDensityComputer fitted on {len(event_ids)} events")
        print(f" Altitude range: {altitudes.min():.0f} - {altitudes.max():.0f} km")
        print(f" Peak density bin: {self.bin_centers[np.argmax(self.density_per_bin)]:.0f} km "
              f"({self.event_counts.max()} events)")
        peak_idx = np.argmax(self.collision_rate)
        if self.collision_rate[peak_idx] > 0:
            print(f" Highest collision rate: {self.bin_centers[peak_idx]:.0f} km "
                  f"(tau = {10**self.crash_clock_log[peak_idx]:.0f} s)")

        return self

    def _get_bin_index(self, altitudes: np.ndarray) -> np.ndarray:
        """Map altitudes to bin indices (clamped to the valid bin range)."""
        indices = np.digitize(altitudes, self.bin_edges) - 1
        return np.clip(indices, 0, self.n_bins - 1)

    def _altitude_percentile(self, altitudes: np.ndarray) -> np.ndarray:
        """Compute percentile in the training altitude distribution."""
        return np.searchsorted(self.altitude_cdf, altitudes) / len(self.altitude_cdf)

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add density features to a CDM DataFrame.

        Features are computed per event_id and broadcast to all CDM rows
        (they're static features — same for every CDM in the sequence).

        Raises:
            RuntimeError: if fit() has not been called (or load() used).
        """
        if not self.is_fitted:
            raise RuntimeError("Must call fit() before transform()")

        df = df.copy()
        altitudes, event_ids = self._event_altitude(df)
        bin_indices = self._get_bin_index(altitudes)

        # Build event-level features
        event_features = {}
        for i, eid in enumerate(event_ids):
            bi = bin_indices[i]
            event_features[eid] = {
                "shell_density": self.density_per_bin[bi],
                "shell_collision_rate": self.collision_rate[bi],
                "local_crash_clock_log": self.crash_clock_log[bi],
                "altitude_percentile": self._altitude_percentile(
                    np.array([altitudes[i]])
                )[0],
                "n_events_in_shell": float(self.event_counts[bi]),
                "shell_risk_rate": self.risk_rate_per_bin[bi],
            }

        # Map features to all CDM rows via event_id; unseen events get 0.0
        for col in DENSITY_FEATURES:
            df[col] = df["event_id"].map(
                {eid: feats[col] for eid, feats in event_features.items()}
            ).fillna(0.0)

        return df

    def save(self, path: Path) -> None:
        """Save fitted state to JSON for inference.

        Raises:
            RuntimeError: if called before fit().
        """
        if not self.is_fitted:
            raise RuntimeError("Must call fit() before save()")
        state = {
            "bin_width_km": self.bin_width_km,
            "bin_edges": self.bin_edges.tolist(),
            "bin_centers": self.bin_centers.tolist(),
            "event_counts": self.event_counts.tolist(),
            "density_per_bin": self.density_per_bin.tolist(),
            "collision_rate": self.collision_rate.tolist(),
            "crash_clock_log": self.crash_clock_log.tolist(),
            "risk_rate_per_bin": self.risk_rate_per_bin.tolist(),
            "altitude_cdf": self.altitude_cdf.tolist(),
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(state, f, indent=2)

    @classmethod
    def load(cls, path: Path) -> "OrbitalDensityComputer":
        """Load fitted state from JSON (inverse of save())."""
        with open(path) as f:
            state = json.load(f)
        obj = cls(bin_width_km=state["bin_width_km"])
        obj.bin_edges = np.array(state["bin_edges"])
        obj.bin_centers = np.array(state["bin_centers"])
        obj.n_bins = len(obj.bin_centers)
        obj.event_counts = np.array(state["event_counts"])
        obj.density_per_bin = np.array(state["density_per_bin"])
        obj.collision_rate = np.array(state["collision_rate"])
        obj.crash_clock_log = np.array(state["crash_clock_log"])
        obj.risk_rate_per_bin = np.array(state["risk_rate_per_bin"])
        obj.altitude_cdf = np.array(state["altitude_cdf"])
        obj.is_fitted = True
        return obj
src/data/firebase_client.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-13
2
+ """Firebase Firestore client for prediction logging.
3
+
4
+ Stores daily conjunction predictions and maneuver detection outcomes.
5
+ Uses the Firestore REST API to avoid heavy SDK dependencies.
6
+ Falls back to local JSONL logging if Firebase is not configured.
7
+
8
+ Environment variables:
9
+ FIREBASE_SERVICE_ACCOUNT: JSON string of the service account key
10
+ FIREBASE_PROJECT_ID: Project ID (auto-detected from service account if not set)
11
+ """
12
+
13
+ import os
14
+ import json
15
+ import time
16
+ import numpy as np
17
+ from pathlib import Path
18
+ from datetime import datetime, timezone
19
+
20
+
21
+ def _json_default(obj):
22
+ """Handle numpy types that json.dumps can't serialize."""
23
+ if isinstance(obj, (np.integer,)):
24
+ return int(obj)
25
+ if isinstance(obj, (np.floating,)):
26
+ return float(obj)
27
+ if isinstance(obj, (np.bool_,)):
28
+ return bool(obj)
29
+ if isinstance(obj, np.ndarray):
30
+ return obj.tolist()
31
+ raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
32
+
33
+ # Try to import google-cloud-firestore (lightweight)
34
+ try:
35
+ from google.cloud.firestore import Client as FirestoreClient
36
+ from google.oauth2.service_account import Credentials
37
+ HAS_FIRESTORE = True
38
+ except ImportError:
39
+ HAS_FIRESTORE = False
40
+
41
+
42
class PredictionLogger:
    """Log predictions to Firebase Firestore with local JSONL fallback.

    Every record is always appended to a JSONL file under ``local_dir``;
    Firestore mirroring happens only when a client could be built from the
    FIREBASE_SERVICE_ACCOUNT environment variable.
    """

    def __init__(self, local_dir: Path = None):
        # None until _init_firebase succeeds; all uploads are skipped then.
        self.db = None
        self.local_dir = local_dir or Path("data/prediction_logs")
        self.local_dir.mkdir(parents=True, exist_ok=True)
        self._init_firebase()

    def _init_firebase(self):
        """Build a Firestore client from the environment, if possible."""
        service_account = os.environ.get("FIREBASE_SERVICE_ACCOUNT", "")
        if not service_account or not HAS_FIRESTORE:
            if not HAS_FIRESTORE:
                print(" Firebase SDK not installed (pip install google-cloud-firestore)")
            print(" Using local JSONL logging only")
            return

        try:
            account_info = json.loads(service_account)
            credentials = Credentials.from_service_account_info(account_info)
            project_id = account_info.get("project_id", os.environ.get("FIREBASE_PROJECT_ID", ""))
            self.db = FirestoreClient(project=project_id, credentials=credentials)
            print(f" Firebase Firestore connected (project: {project_id})")
        except Exception as e:
            print(f" Firebase init failed: {e}")
            print(" Falling back to local JSONL logging")

    def log_predictions(self, date_str: str, predictions: list[dict]):
        """Log a batch of daily predictions.

        Args:
            date_str: Date string (YYYY-MM-DD).
            predictions: Prediction dicts (sat1_norad, sat2_norad, sat1_name,
                sat2_name, risk_score, altitude_km, model_used). Each dict is
                stamped in place with 'date' and 'logged_at'.
        """
        # The local JSONL copy is written unconditionally.
        local_file = self.local_dir / f"{date_str}.jsonl"
        with open(local_file, "a") as f:
            for pred in predictions:
                pred["date"] = date_str
                pred["logged_at"] = datetime.now(timezone.utc).isoformat()
                f.write(json.dumps(pred, default=_json_default) + "\n")
        print(f" Saved {len(predictions)} predictions to {local_file}")

        if not self.db:
            return
        try:
            batch = self.db.batch()
            # Parent document records the date plus how many pairs follow.
            day_doc = self.db.collection("predictions").document(date_str)
            day_doc.set({"date": date_str, "count": len(predictions)})
            for i, pred in enumerate(predictions):
                pair_ref = self.db.collection("predictions").document(date_str) \
                    .collection("pairs").document(f"pair_{i:04d}")
                batch.set(pair_ref, pred)
            batch.commit()
            print(f" Uploaded {len(predictions)} predictions to Firebase")
        except Exception as e:
            print(f" Firebase upload failed: {e}")

    def log_outcomes(self, date_str: str, outcomes: list[dict]):
        """Log maneuver detection outcomes for a previous prediction date.

        Args:
            date_str: Original prediction date (YYYY-MM-DD).
            outcomes: Outcome dicts (sat1_norad, sat2_norad, sat1_maneuvered,
                sat2_maneuvered, sat1_delta_a_m, sat2_delta_a_m). Each dict is
                stamped in place with 'prediction_date' and 'validated_at'.
        """
        local_file = self.local_dir / f"{date_str}_outcomes.jsonl"
        with open(local_file, "a") as f:
            for outcome in outcomes:
                outcome["prediction_date"] = date_str
                outcome["validated_at"] = datetime.now(timezone.utc).isoformat()
                f.write(json.dumps(outcome, default=_json_default) + "\n")
        print(f" Saved {len(outcomes)} outcomes to {local_file}")

        if not self.db:
            return
        try:
            batch = self.db.batch()
            for i, outcome in enumerate(outcomes):
                result_ref = self.db.collection("outcomes").document(date_str) \
                    .collection("results").document(f"result_{i:04d}")
                batch.set(result_ref, outcome)
            batch.commit()
            print(f" Uploaded {len(outcomes)} outcomes to Firebase")
        except Exception as e:
            print(f" Firebase upload failed: {e}")

    def log_daily_summary(self, date_str: str, summary: dict):
        """Log a daily summary (n_predictions, n_maneuvers_detected, accuracy, etc)."""
        summary["date"] = date_str
        with open(self.local_dir / "daily_summaries.jsonl", "a") as f:
            f.write(json.dumps(summary, default=_json_default) + "\n")

        if not self.db:
            return
        try:
            self.db.collection("daily_summaries").document(date_str).set(summary)
            print(f" Uploaded daily summary to Firebase")
        except Exception as e:
            print(f" Firebase summary upload failed: {e}")

    def get_predictions_for_date(self, date_str: str) -> list[dict]:
        """Retrieve predictions for a date (from local files)."""
        local_file = self.local_dir / f"{date_str}.jsonl"
        if not local_file.exists():
            return []
        with open(local_file) as f:
            return [json.loads(line) for line in f if line.strip()]
src/data/maneuver_classifier.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Classify detected satellite maneuvers into avoidance vs routine.
2
+
3
+ Enriches each maneuver with:
4
+ - magnitude_class: micro/small/medium/large based on delta-v
5
+ - constellation: starlink/oneweb/iridium/other
6
+ - is_stationkeeping: regularity-based detection from maneuver history
7
+ - likely_avoidance: heuristic combining all signals
8
+
9
+ These enrichments improve training label quality for PI-TFT fine-tuning
10
+ without changing the model's feature space.
11
+ """
12
+
13
+ import re
14
+ import numpy as np
15
+ from datetime import datetime
16
+
17
+
18
# Delta-v magnitude bins (m/s): (label, inclusive lower bound, exclusive upper bound)
MAGNITUDE_BINS = [
    ("micro", 0.0, 0.5),
    ("small", 0.5, 2.0),
    ("medium", 2.0, 10.0),
    ("large", 10.0, float("inf")),
]

# Regexes used to attribute a satellite name to a known constellation.
CONSTELLATION_PATTERNS = [
    ("starlink", re.compile(r"STARLINK", re.IGNORECASE)),
    ("oneweb", re.compile(r"ONEWEB", re.IGNORECASE)),
    ("iridium", re.compile(r"IRIDIUM", re.IGNORECASE)),
]

# Stationkeeping regularity threshold (coefficient of variation of intervals)
STATIONKEEPING_CV_THRESHOLD = 0.3
MIN_HISTORY_FOR_SK = 3  # Need at least 3 past maneuvers to detect pattern


def classify_magnitude(delta_v_m_s: float) -> str:
    """Return the magnitude-bin label for a delta-v, using its absolute value."""
    magnitude = abs(delta_v_m_s)
    return next(
        (label for label, lo, hi in MAGNITUDE_BINS if lo <= magnitude < hi),
        "large",  # fallback for values no bin matches (e.g. NaN)
    )
45
+
46
+
47
def detect_constellation(name: str) -> str:
    """Identify the constellation a satellite name belongs to; 'other'
    when no known pattern matches."""
    return next(
        (label for label, pattern in CONSTELLATION_PATTERNS if pattern.search(name)),
        "other",
    )
53
+
54
+
55
def detect_stationkeeping(history: list[dict]) -> bool:
    """Detect stationkeeping from regularity of past maneuver intervals.

    A satellite doing routine stationkeeping fires at near-regular
    intervals, so the coefficient of variation (std/mean) of the gaps
    between consecutive maneuvers stays small.

    Args:
        history: Past maneuver records for this NORAD ID, each with a
            'detected_at' ISO timestamp.

    Returns:
        True if the maneuver cadence suggests stationkeeping.
    """
    if not history or len(history) < MIN_HISTORY_FOR_SK:
        return False

    # Collect parseable epochs (seconds since the Unix epoch).
    epochs = []
    for record in history:
        stamp = record.get("detected_at", "")
        if not stamp:
            continue
        try:
            parsed = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
        except (ValueError, TypeError):
            continue
        epochs.append(parsed.timestamp())

    if len(epochs) < MIN_HISTORY_FOR_SK:
        return False

    gaps = np.diff(sorted(epochs))
    if len(gaps) < 2:
        return False

    mean_gap = np.mean(gaps)
    if mean_gap <= 0:
        return False
    return (np.std(gaps) / mean_gap) < STATIONKEEPING_CV_THRESHOLD
98
+
99
+
100
def classify_maneuver(maneuver: dict, history: list[dict] = None) -> dict:
    """Classify a detected maneuver with enrichment flags.

    Args:
        maneuver: Maneuver dict from detect_maneuvers() with keys:
            norad_id, name, delta_v_m_s, delta_a_m, etc.
        history: Past maneuver records for the same NORAD ID (optional).

    Returns:
        A copy of *maneuver* with magnitude/constellation/stationkeeping/
        avoidance enrichment fields added.
    """
    delta_v = maneuver.get("delta_v_m_s", 0.0)
    sat_name = maneuver.get("name", "")

    magnitude_class = classify_magnitude(delta_v)
    constellation = detect_constellation(sat_name)
    is_sk = detect_stationkeeping(history) if history else False

    # Avoidance heuristic: a small, non-routine burn — or a sub-1 m/s
    # Starlink burn, which is the typical Starlink CAM signature.
    likely_avoidance = (
        (not is_sk and magnitude_class in ("micro", "small") and delta_v < 5.0)
        or (constellation == "starlink" and delta_v < 1.0)
    )

    enriched = dict(maneuver)
    enriched.update({
        "magnitude_class": magnitude_class,
        "constellation": constellation,
        "is_stationkeeping": is_sk,
        "likely_avoidance": likely_avoidance,
        "enrichment_version": 1,
        # Phase B/C defaults — overwritten later if data is available
        "has_cdm": False,
        "cdm_pc": None,
        "cdm_miss_distance_km": None,
        "counterfactual_min_distance_km": None,
        "would_have_collided": False,
        "counterfactual_closest_norad": None,
    })
    return enriched
src/data/maneuver_detector.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-13
2
+ """Detect satellite maneuvers from TLE data changes.
3
+
4
+ Compares successive TLEs for the same satellite. An abrupt change in
5
+ semi-major axis (> threshold) indicates a maneuver — either collision
6
+ avoidance, orbit maintenance, or orbit raising.
7
+
8
+ Based on Kelecy (2007) and Patera & Peterson (2021).
9
+ """
10
+
11
+ import json
12
+ import math
13
+ import numpy as np
14
+ from pathlib import Path
15
+ from datetime import datetime, timedelta, timezone
16
+
17
+
18
# Earth parameters (WGS84)
MU_EARTH = 398600.4418  # gravitational parameter, km^3/s^2
EARTH_RADIUS_KM = 6378.137

# Maneuver detection thresholds
DEFAULT_DELTA_A_THRESHOLD_M = 200  # meters — below this is noise
STARLINK_DELTA_A_THRESHOLD_M = 100  # Starlink maneuvers can be smaller


def mean_motion_to_sma(n_rev_per_day: float) -> float:
    """Convert mean motion (rev/day) to semi-major axis (km) via Kepler's
    third law; non-positive input yields 0.0."""
    if n_rev_per_day <= 0:
        return 0.0
    rad_per_sec = n_rev_per_day * 2 * math.pi / 86400.0
    return (MU_EARTH / (rad_per_sec ** 2)) ** (1.0 / 3.0)


def sma_to_altitude(sma_km: float) -> float:
    """Approximate altitude (km) above the mean equatorial radius."""
    return sma_km - EARTH_RADIUS_KM
38
+
39
+
40
def parse_tle_epoch(epoch_str: str) -> datetime:
    """Parse a CelesTrak JSON epoch string (ISO 8601-style).

    Tries microsecond, whole-second, and date-only layouts in turn, e.g.
    "2026-02-13T12:00:00.000000".

    Raises:
        ValueError: if none of the known formats match.
    """
    known_formats = ("%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d")
    for fmt in known_formats:
        try:
            return datetime.strptime(epoch_str, fmt)
        except ValueError:
            pass
    raise ValueError(f"Cannot parse epoch: {epoch_str}")
49
+
50
+
51
def extract_orbital_elements(tle_json: dict) -> dict:
    """Extract key orbital elements from a CelesTrak JSON TLE entry.

    Returns identifiers, raw elements, the derived semi-major axis and
    altitude, plus the parsed epoch (None when the EPOCH field is absent
    or unparseable).
    """
    def _num(key):
        # Missing fields default to 0 to match the schema's conventions.
        return float(tle_json.get(key, 0))

    sma_km = mean_motion_to_sma(_num("MEAN_MOTION"))

    epoch_raw = tle_json.get("EPOCH", "")
    epoch_parsed = None
    if epoch_raw:
        try:
            epoch_parsed = parse_tle_epoch(epoch_raw)
        except ValueError:
            pass

    return {
        "norad_id": int(tle_json.get("NORAD_CAT_ID", 0)),
        "name": tle_json.get("OBJECT_NAME", "UNKNOWN"),
        "mean_motion": _num("MEAN_MOTION"),
        "eccentricity": _num("ECCENTRICITY"),
        "inclination": _num("INCLINATION"),
        "raan": _num("RA_OF_ASC_NODE"),
        "sma_km": sma_km,
        "altitude_km": sma_to_altitude(sma_km),
        "epoch": epoch_parsed,
        "epoch_str": epoch_raw,
    }
83
+
84
+
85
def detect_maneuvers(
    prev_tles: list[dict],
    curr_tles: list[dict],
    threshold_m: float = DEFAULT_DELTA_A_THRESHOLD_M,
) -> list[dict]:
    """Compare two TLE snapshots and report semi-major-axis jumps.

    Args:
        prev_tles: Previous TLE snapshot (CelesTrak JSON format).
        curr_tles: Current TLE snapshot (CelesTrak JSON format).
        threshold_m: Minimum |delta-SMA| in meters to count as a maneuver.

    Returns:
        Maneuver records sorted by |delta-SMA| descending.
    """
    # Index the previous snapshot by NORAD ID, dropping invalid entries.
    baseline = {}
    for entry in prev_tles:
        parsed = extract_orbital_elements(entry)
        if parsed["norad_id"] > 0 and parsed["sma_km"] > 0:
            baseline[parsed["norad_id"]] = parsed

    detections = []
    for entry in curr_tles:
        current = extract_orbital_elements(entry)
        previous = baseline.get(current["norad_id"])
        if previous is None or current["sma_km"] <= 0:
            continue

        delta_a_km = current["sma_km"] - previous["sma_km"]
        delta_a_m = abs(delta_a_km) * 1000
        if delta_a_m <= threshold_m:
            continue

        # The sign of the SMA change distinguishes raise from lower.
        maneuver_type = "orbit_raise" if delta_a_km > 0 else "orbit_lower"

        # First-order (Hohmann-style) delta-v estimate from circular speed.
        v_circular = math.sqrt(MU_EARTH / previous["sma_km"])  # km/s
        delta_v = abs(delta_a_km) / (2 * previous["sma_km"]) * v_circular * 1000  # m/s

        detections.append({
            "norad_id": current["norad_id"],
            "name": current["name"],
            "prev_sma_km": previous["sma_km"],
            "curr_sma_km": current["sma_km"],
            "delta_a_m": delta_a_m,
            "delta_a_km": delta_a_km,
            "delta_v_m_s": round(delta_v, 3),
            "maneuver_type": maneuver_type,
            "altitude_km": current["altitude_km"],
            "prev_epoch": previous["epoch_str"],
            "curr_epoch": current["epoch_str"],
            "detected_at": datetime.now(timezone.utc).isoformat(),
        })

    # Largest maneuvers first.
    detections.sort(key=lambda m: m["delta_a_m"], reverse=True)
    return detections
148
+
149
+
150
def detect_maneuvers_dual_threshold(
    prev_tles: list[dict],
    curr_tles: list[dict],
) -> list[dict]:
    """Detect maneuvers using constellation-aware thresholds.

    Runs detection at 100 m for Starlink objects (which maneuver with
    smaller SMA changes) and 200 m for everything else, then returns the
    combined list sorted by |delta-SMA| descending.
    """
    def _split_by_constellation(tles):
        # Partition a snapshot into (starlink, other) by object name.
        starlink, other = [], []
        for tle in tles:
            if "STARLINK" in tle.get("OBJECT_NAME", "").upper():
                starlink.append(tle)
            else:
                other.append(tle)
        return starlink, other

    starlink_prev, other_prev = _split_by_constellation(prev_tles)
    starlink_curr, other_curr = _split_by_constellation(curr_tles)

    merged = detect_maneuvers(
        starlink_prev, starlink_curr,
        threshold_m=STARLINK_DELTA_A_THRESHOLD_M,
    )
    merged += detect_maneuvers(
        other_prev, other_curr,
        threshold_m=DEFAULT_DELTA_A_THRESHOLD_M,
    )

    merged.sort(key=lambda m: m["delta_a_m"], reverse=True)
    return merged
193
+
194
+
195
def load_tle_snapshot(path: Path) -> list[dict]:
    """Read a TLE snapshot back from a JSON file."""
    return json.loads(Path(path).read_text())
199
+
200
+
201
def save_tle_snapshot(tles: list[dict], path: Path):
    """Write a TLE snapshot to a JSON file, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    Path(path).write_text(json.dumps(tles))
src/data/merge_sources.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Merge CDM data from multiple sources into unified training format.
3
+
4
+ Combines:
5
+ 1. ESA Kelvins dataset (103 features, labeled)
6
+ 2. Space-Track cdm_public (16 features, unlabeled — derive risk from PC)
7
+
8
+ Strategy:
9
+ - Space-Track CDMs are grouped into "conjunction events" by (SAT_1_ID, SAT_2_ID, TCA_date)
10
+ - Each event gets a time series of CDMs ordered by CREATED date
11
+ - Risk label derived from final PC: high risk if PC > 1e-5 (same threshold as Kelvins)
12
+ - Features that exist in both sources get unified column names
13
+ - Missing features (e.g., covariance in Space-Track) are filled with 0
14
+
15
+ This gives us far more positive examples for training the risk classifier,
16
+ even though the Space-Track data has fewer features per CDM.
17
+ """
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+ from pathlib import Path
22
+ from datetime import timedelta
23
+
24
+
25
# Mapping from Space-Track CDM_PUBLIC fields → unified column names
SPACETRACK_COLUMN_MAP = {
    "CDM_ID": "cdm_id",
    "CREATED": "created",
    "TCA": "tca",
    "MIN_RNG": "miss_distance",  # km in Space-Track
    "PC": "collision_probability",
    "SAT_1_ID": "sat_1_id",
    "SAT_1_NAME": "sat_1_name",
    "SAT1_OBJECT_TYPE": "t_object_type",
    "SAT1_RCS": "t_rcs",
    "SAT_1_EXCL_VOL": "t_excl_vol",
    "SAT_2_ID": "sat_2_id",
    "SAT_2_NAME": "sat_2_name",
    "SAT2_OBJECT_TYPE": "c_object_type",
    "SAT2_RCS": "c_rcs",
    "SAT_2_EXCL_VOL": "c_excl_vol",
    "EMERGENCY_REPORTABLE": "emergency_reportable",
}

# Risk threshold: PC > 1e-5 = high risk (matches ESA Kelvins: risk > -5)
RISK_THRESHOLD = 1e-5


def load_spacetrack_cdms(csv_path: Path) -> pd.DataFrame:
    """Load a Space-Track CDM CSV and normalize it to the unified schema.

    - Renames raw CDM_PUBLIC columns via SPACETRACK_COLUMN_MAP.
    - Parses 'created'/'tca' as datetimes (unparseable -> NaT).
    - Converts MIN_RNG from km to meters (Kelvins convention).
    - Derives 'risk' = log10(PC), floored at -30 for non-positive PC.
    """
    df = pd.read_csv(csv_path).rename(columns=SPACETRACK_COLUMN_MAP)

    for date_col in ("created", "tca"):
        if date_col in df.columns:
            df[date_col] = pd.to_datetime(df[date_col], errors="coerce")

    if "miss_distance" in df.columns:
        # Space-Track MIN_RNG is in km; ESA Kelvins miss_distance is in
        # meters — convert for consistency.
        df["miss_distance"] = pd.to_numeric(df["miss_distance"], errors="coerce") * 1000.0

    if "collision_probability" in df.columns:
        df["collision_probability"] = pd.to_numeric(
            df["collision_probability"], errors="coerce"
        )
        # log10 of PC matches the ESA Kelvins risk format.
        df["risk"] = np.where(
            df["collision_probability"] > 0,
            np.log10(df["collision_probability"].clip(lower=1e-30)),
            -30.0,
        )

    print(f"Loaded {len(df)} Space-Track CDMs from {csv_path.name}")
    return df
82
+
83
+
84
def group_into_events(df: pd.DataFrame) -> pd.DataFrame:
    """
    Group Space-Track CDMs into conjunction events.

    CDMs for the same (sat_1_id, sat_2_id) pair whose TCAs fall within
    one day of each other share an event_id (assigned sequentially from 1).
    Also derives 'time_to_tca' in days (clipped at zero) when both
    'created' and 'tca' exist.
    """
    if df.empty:
        return df

    # Stable ordering so consecutive rows of one pair are adjacent.
    df = df.sort_values(["sat_1_id", "sat_2_id", "tca", "created"]).reset_index(drop=True)

    event_ids = []
    next_event = 0
    last_sat1 = last_sat2 = last_tca = None

    for _, row in df.iterrows():
        sat1, sat2, tca = row.get("sat_1_id"), row.get("sat_2_id"), row.get("tca")

        same_pair = sat1 == last_sat1 and sat2 == last_sat2
        continues_event = (
            same_pair
            and last_tca is not None
            and pd.notna(tca)
            and pd.notna(last_tca)
            and abs((tca - last_tca).total_seconds()) < 86400  # 1 day
        )
        if not continues_event:
            next_event += 1
        event_ids.append(next_event)
        last_sat1, last_sat2, last_tca = sat1, sat2, tca

    df["event_id"] = event_ids

    # Days from CDM creation to closest approach, never negative.
    if "created" in df.columns and "tca" in df.columns:
        delta_days = (df["tca"] - df["created"]).dt.total_seconds() / 86400.0
        df["time_to_tca"] = delta_days.clip(lower=0.0)

    n_events = df["event_id"].nunique()
    n_high_risk = 0
    if "risk" in df.columns:
        # The final CDM of each event decides whether it counts as high-risk.
        n_high_risk = (df.groupby("event_id")["risk"].last() > -5).sum()

    print(f"Grouped into {n_events} events ({n_high_risk} high-risk)")
    return df
137
+
138
+
139
def compute_relative_speed_from_excl_vol(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure a 'relative_speed' column exists (filled with 0.0).

    Exclusion volume alone cannot yield a speed estimate, so this is a
    schema-compatibility shim rather than a physical computation.
    """
    if "relative_speed" in df.columns:
        return df
    df["relative_speed"] = 0.0
    return df
146
+
147
+
148
def align_with_kelvins_schema(
    spacetrack_df: pd.DataFrame,
    kelvins_df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Align Space-Track data columns with the Kelvins schema.

    Columns present in Kelvins but missing from the Space-Track frame are
    added filled with 0.0; only Kelvins columns plus a fixed set of
    Space-Track metadata columns are kept.

    Fix: works on a copy so the caller's DataFrame is no longer mutated
    (the previous version added the 0.0 fill columns to the input in place).

    Args:
        spacetrack_df: Space-Track CDMs in unified column names.
        kelvins_df: ESA Kelvins DataFrame defining the target schema.

    Returns:
        A new DataFrame with the aligned column set.
    """
    kelvins_cols = set(kelvins_df.columns)

    # Copy before adding fill columns — never mutate the caller's frame.
    aligned = spacetrack_df.copy()
    for col in kelvins_cols - set(aligned.columns):
        aligned[col] = 0.0

    # Kelvins schema plus extra Space-Track metadata worth preserving.
    extra_cols = {"sat_1_id", "sat_2_id", "sat_1_name", "sat_2_name",
                  "t_object_type", "collision_probability", "created", "tca",
                  "cdm_id", "emergency_reportable", "t_rcs", "c_rcs",
                  "t_excl_vol", "c_excl_vol", "source"}
    keep_cols = [c for c in (kelvins_cols | extra_cols) if c in aligned.columns]
    return aligned[keep_cols]
173
+
174
+
175
def merge_datasets(
    kelvins_train_df: pd.DataFrame,
    spacetrack_df: pd.DataFrame,
    offset_event_ids: bool = True,
) -> pd.DataFrame:
    """
    Merge Kelvins training data with Space-Track CDMs into one frame.

    Args:
        kelvins_train_df: ESA Kelvins training DataFrame.
        spacetrack_df: Space-Track CDMs (already grouped into events).
        offset_event_ids: shift Space-Track event_ids past the Kelvins
            maximum so the two ID spaces cannot collide.

    Returns:
        Combined DataFrame ready for model training.
    """
    # Tag each source on copies so callers' frames stay untouched.
    kelvins = kelvins_train_df.copy()
    kelvins["source"] = "kelvins"
    spacetrack = spacetrack_df.copy()
    spacetrack["source"] = "spacetrack"

    if offset_event_ids and "event_id" in kelvins.columns:
        spacetrack["event_id"] = spacetrack["event_id"] + kelvins["event_id"].max() + 1

    spacetrack = align_with_kelvins_schema(spacetrack, kelvins)

    combined = pd.concat([kelvins, spacetrack], ignore_index=True)

    # Remaining gaps in numeric columns become 0.
    numeric_cols = combined.select_dtypes(include=[np.number]).columns
    combined[numeric_cols] = combined[numeric_cols].fillna(0)

    n_kelvins = kelvins["event_id"].nunique()
    n_st = spacetrack["event_id"].nunique()
    n_total = combined["event_id"].nunique()

    # High-risk counts per source, judged by each event's final CDM.
    event_risk = combined.groupby(["event_id", "source"])["risk"].last().reset_index()
    n_hr_kelvins = ((event_risk["source"] == "kelvins") & (event_risk["risk"] > -5)).sum()
    n_hr_st = ((event_risk["source"] == "spacetrack") & (event_risk["risk"] > -5)).sum()

    print(f"\nMerged dataset:")
    print(f" Kelvins: {n_kelvins} events ({n_hr_kelvins} high-risk)")
    print(f" Space-Track: {n_st} events ({n_hr_st} high-risk)")
    print(f" Total: {n_total} events ({n_hr_kelvins + n_hr_st} high-risk)")
    print(f" Columns: {len(combined.columns)}")

    return combined
229
+
230
+
231
def load_and_merge_all(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load all available data sources and merge into train/test DataFrames.

    Returns (train_df, test_df) — the test split stays Kelvins-only so
    benchmark numbers remain comparable across data configurations.
    """
    from src.data.cdm_loader import load_dataset

    kelvins_train, kelvins_test = load_dataset(data_dir / "cdm")

    spacetrack_dir = data_dir / "cdm_spacetrack"
    csv_files = list(spacetrack_dir.glob("cdm_*.csv")) if spacetrack_dir.exists() else []

    if not csv_files:
        print("\nNo Space-Track data found. Using Kelvins only.")
        return kelvins_train, kelvins_test

    frames = []
    for csv_file in csv_files:
        # Checkpoint files are partial downloads — skip them.
        if csv_file.name.startswith("checkpoint"):
            continue
        frame = load_spacetrack_cdms(csv_file)
        frame = group_into_events(frame)
        frames.append(compute_relative_speed_from_excl_vol(frame))

    if frames:
        combined_st = pd.concat(frames, ignore_index=True)
        # Event IDs must be reassigned once per-file frames are combined.
        combined_st = group_into_events(combined_st)
        merged_train = merge_datasets(kelvins_train, combined_st)
    else:
        merged_train = kelvins_train

    return merged_train, kelvins_test
src/data/sequence_builder.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Build padded CDM sequences for the Temporal Fusion Transformer.
3
+
4
+ Each conjunction event is a variable-length time series of CDM snapshots.
5
+ This module handles:
6
+ - Selecting temporal vs static features
7
+ - Padding/truncating to fixed length
8
+ - Creating attention masks for padded positions
9
+ - Train/val/test splitting with stratification
10
+ """
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import torch
15
+ from torch.utils.data import Dataset
16
+ from sklearn.model_selection import train_test_split
17
+ from pathlib import Path
18
+
19
# Hard cap on CDMs per event sequence (95th percentile of real data is ~25).
MAX_SEQ_LEN = 30

# Features that change with each CDM update (time-varying inputs).
TEMPORAL_FEATURES = [
    "miss_distance",
    "relative_speed",
    "relative_position_r", "relative_position_t", "relative_position_n",
    "relative_velocity_r", "relative_velocity_t", "relative_velocity_n",
    "max_risk_estimate", "max_risk_scaling",
    # Target object covariance (position and velocity sigmas, RTN frame)
    "t_sigma_r", "t_sigma_t", "t_sigma_n",
    "t_sigma_rdot", "t_sigma_tdot", "t_sigma_ndot",
    # Chaser object covariance (position and velocity sigmas, RTN frame)
    "c_sigma_r", "c_sigma_t", "c_sigma_n",
    "c_sigma_rdot", "c_sigma_tdot", "c_sigma_ndot",
]

# Features that are constant per event (object properties).
STATIC_FEATURES = [
    "t_h_apo", "t_h_per", "t_j2k_sma", "t_j2k_inc", "t_ecc",
    "c_h_apo", "c_h_per", "c_j2k_sma", "c_j2k_inc", "c_ecc",
    "t_span", "c_span",
]

# Orbital density features from CRASH Clock analysis (added by OrbitalDensityComputer).
DENSITY_FEATURES = [
    "shell_density",
    "shell_collision_rate",
    "local_crash_clock_log",
    "altitude_percentile",
    "n_events_in_shell",
    "shell_risk_rate",
]
53
+
54
+
55
def find_available_features(df: pd.DataFrame, candidates: list[str]) -> list[str]:
    """Return the subset of *candidates* present as columns in *df*,
    preserving candidate order; prints a note when some are missing."""
    columns = set(df.columns)
    available = [name for name in candidates if name in columns]
    n_missing = len(candidates) - len(available)
    if n_missing:
        print(f" Note: {n_missing} features not in dataset, using {len(available)}")
    return available
62
+
63
+
64
class CDMSequenceDataset(Dataset):
    """
    PyTorch Dataset that serves padded CDM sequences for the Transformer.

    One item = one conjunction event (all of its CDM snapshots). Each item
    is a dict of tensors with keys:
        - "temporal": (max_seq_len, 2 * n_temporal) normalized time-varying
          features concatenated with their normalized first-order deltas
        - "static": (n_static,) normalized object properties
        - "time_to_tca": (max_seq_len, 1) normalized time-to-closest-approach
        - "mask": (max_seq_len,) bool — True for real data, False for padding
        - "risk_label": scalar binary target (1.0 if final log10(Pc) > -5)
        - "miss_log": scalar log1p(final miss distance) target
        - "pc_log10": scalar final log10(Pc), clamped to [-20, 0]
        - "domain_weight": 1.0 for Kelvins events, 0.3 for Space-Track
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] | None = None,
        static_cols: list[str] | None = None,
    ):
        """Group *df* rows into per-event sequences and fit normalization stats.

        Args:
            df: CDM rows; must contain "event_id", "time_to_tca", "risk".
            max_seq_len: fixed sequence length after padding/truncation.
            temporal_cols: override for the time-varying feature columns.
            static_cols: override for the per-event static feature columns.
        """
        self.max_seq_len = max_seq_len

        # Find available features (fall back to module defaults filtered to df)
        self.temporal_cols = temporal_cols or find_available_features(df, TEMPORAL_FEATURES)
        self.static_cols = static_cols or find_available_features(df, STATIC_FEATURES)

        print(f"  Temporal features: {len(self.temporal_cols)}")
        print(f"  Static features: {len(self.static_cols)}")

        # Group by event_id — one dataset item per conjunction event
        self.events = []
        for event_id, group in df.groupby("event_id"):
            # Sort by time_to_tca descending (first CDM = furthest from TCA)
            group = group.sort_values("time_to_tca", ascending=False)
            # Track data source for domain weighting in __getitem__
            source = "kelvins"
            if "source" in group.columns:
                source = group["source"].iloc[0]
            self.events.append({
                "event_id": event_id,
                "group": group,
                "source": source,
            })

        # Compute global normalization stats from training data.
        # NOTE(review): stats are computed over ALL rows of df, including
        # rows later truncated away in __getitem__ — accepted approximation.
        self.temporal_mean = df[self.temporal_cols].mean().values.astype(np.float32)
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[self.temporal_std < 1e-8] = 1.0  # avoid div by zero

        self.static_mean = df[self.static_cols].mean().values.astype(np.float32)
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[self.static_std < 1e-8] = 1.0

        # Normalize time_to_tca with scalar mean/std
        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if self.tca_std < 1e-8:
            self.tca_std = 1.0

        # Compute delta normalization stats (approx from per-step differences).
        # Deltas have different magnitude than raw features, need separate stats.
        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame):
        """Estimate normalization stats for temporal first-order differences.

        Samples per-event np.diff values until ~2000 events are seen, then
        takes mean/std per feature. Falls back to zero-mean/unit-std when no
        event has at least two CDMs.
        """
        # Sample a subset of events to estimate delta distributions
        delta_samples = []
        for _, group in df.groupby("event_id"):
            if len(group) < 2:
                continue  # need at least two CDMs to form a difference
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[self.delta_std < 1e-8] = 1.0
        else:
            # Degenerate case: identity normalization
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other: "CDMSequenceDataset"):
        """Copy normalization stats from another dataset (e.g., training set).

        Call this on val/test datasets so all splits share the training
        distribution's statistics (arrays are shared by reference, not copied).
        """
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self) -> int:
        """Number of conjunction events (not CDM rows)."""
        return len(self.events)

    def __getitem__(self, idx: int) -> dict:
        """Return one event as a dict of tensors (see class docstring for keys)."""
        event = self.events[idx]
        group = event["group"]

        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute first-order differences (deltas) for temporal features.
        # This captures trends: is miss_distance shrinking? Is covariance tightening?
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)  # (seq_len-1, n_temporal)
            # Prepend zeros for the first timestep (no prior to diff against)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)

        # Normalize raw features and deltas separately (different magnitudes)
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std

        # Concatenate: (seq_len, n_temporal * 2)
        temporal = np.concatenate([temporal, deltas], axis=1)

        # Extract static features from last row (they're constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)

        # Time-to-TCA values: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)

        # Normalize with training-set statistics
        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std

        # Truncate or pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            # Keep the most recent CDMs (closest to TCA = most informative)
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len

        # Pad (left-pad so the most recent CDM is always at position -1)
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True

        # Target: risk label from final CDM's risk column.
        # risk > -5 means collision probability > 1e-5 (high risk).
        # NOTE(review): a NaN final_risk silently yields label 0.0 here.
        final_risk = group["risk"].iloc[-1]
        risk_label = 1.0 if final_risk > -5 else 0.0

        # Target: log1p of final miss distance (0.0 if column absent)
        final_miss = group["miss_distance"].iloc[-1] if "miss_distance" in group.columns else 0.0
        miss_log = np.log1p(max(final_miss, 0.0))

        # Target: log10(Pc) — the Kelvins `risk` column is already log10(Pc).
        # Clamp to [-20, 0] (Pc ranges from ~1e-20 to ~1)
        pc_log10 = float(max(min(final_risk, 0.0), -20.0))

        # Domain weight: Kelvins events get full weight, Space-Track events
        # get reduced weight since they have sparse features (16 vs 103 columns).
        # This prevents the model from learning shortcuts on zero-padded features.
        source = event.get("source", "kelvins")
        domain_weight = 1.0 if source == "kelvins" else 0.3

        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
            "risk_label": torch.tensor(risk_label, dtype=torch.float32),
            "miss_log": torch.tensor(miss_log, dtype=torch.float32),
            "pc_log10": torch.tensor(pc_log10, dtype=torch.float32),
            "domain_weight": torch.tensor(domain_weight, dtype=torch.float32),
        }
246
+
247
+
248
class PretrainDataset(Dataset):
    """Simplified CDM dataset for self-supervised pre-training (no labels needed).

    Returns only "temporal", "static", "time_to_tca", and "mask" — the same
    input encoding as CDMSequenceDataset but without any targets, so it can
    process combined train+test data since labels aren't used.

    NOTE(review): the preprocessing here mirrors CDMSequenceDataset (including
    a duplicated _compute_delta_stats); the two must be kept in sync.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] | None = None,
        static_cols: list[str] | None = None,
    ):
        """Group *df* into per-event sequences and fit normalization stats."""
        self.max_seq_len = max_seq_len

        self.temporal_cols = temporal_cols or find_available_features(df, TEMPORAL_FEATURES)
        self.static_cols = static_cols or find_available_features(df, STATIC_FEATURES)

        print(f"  PretrainDataset — Temporal: {len(self.temporal_cols)}, Static: {len(self.static_cols)}")

        # Group by event_id — one item per conjunction event
        self.events = []
        for event_id, group in df.groupby("event_id"):
            group = group.sort_values("time_to_tca", ascending=False)
            self.events.append({"event_id": event_id, "group": group})

        # Compute global normalization stats
        self.temporal_mean = df[self.temporal_cols].mean().values.astype(np.float32)
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[self.temporal_std < 1e-8] = 1.0

        self.static_mean = df[self.static_cols].mean().values.astype(np.float32)
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[self.static_std < 1e-8] = 1.0

        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if self.tca_std < 1e-8:
            self.tca_std = 1.0

        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame):
        """Estimate normalization stats for temporal first-order differences.

        Same algorithm as CDMSequenceDataset._compute_delta_stats: sample up
        to ~2000 multi-CDM events, take per-feature mean/std of np.diff.
        """
        delta_samples = []
        for _, group in df.groupby("event_id"):
            if len(group) < 2:
                continue  # need two CDMs to form a difference
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[self.delta_std < 1e-8] = 1.0
        else:
            # Degenerate case: identity normalization
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other):
        """Copy normalization stats from another dataset (shared by reference)."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self) -> int:
        """Number of conjunction events."""
        return len(self.events)

    def __getitem__(self, idx: int) -> dict:
        """Return one event's input tensors (no targets; see class docstring)."""
        event = self.events[idx]
        group = event["group"]

        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute first-order differences (trend signal), zero for first step
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)

        # Normalize raw features and deltas with their separate stats
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std
        temporal = np.concatenate([temporal, deltas], axis=1)

        # Static features from last row (constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)

        # Time-to-TCA: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)

        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std

        # Truncate (keep most recent CDMs) or left-pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len

        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

        # Attention mask: True = real position, False = padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True

        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
        }
378
+
379
+
380
def build_datasets(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    val_fraction: float = 0.1,
    use_density: bool = False,
    cal_fraction: float = 0.0,
) -> tuple:
    """
    Build train, validation, and test datasets with shared normalization.

    Splits training data into train + val by event_id (stratified by risk).
    Val/test/cal datasets reuse the training set's normalization statistics.

    Args:
        train_df: Training CDM DataFrame
        test_df: Test CDM DataFrame
        val_fraction: Fraction of Kelvins training events for validation
        use_density: If True, include DENSITY_FEATURES in static features
        cal_fraction: If > 0, further split validation into val + calibration
            for conformal prediction. Returns 4-tuple instead of 3.

    Returns:
        If cal_fraction == 0: (train_ds, val_ds, test_ds)
        If cal_fraction > 0 AND the validation split has more than 20 events:
            (train_ds, val_ds, cal_ds, test_ds)
        Otherwise the calibration split is silently skipped and a 3-tuple is
        returned — callers passing cal_fraction > 0 must handle both arities.
    """
    # Compute density features if requested (project-local import kept lazy
    # so the module works without the density dependency)
    if use_density:
        from src.data.density_features import OrbitalDensityComputer
        density_computer = OrbitalDensityComputer()
        density_computer.fit(train_df)
        train_df = density_computer.transform(train_df)
        test_df = density_computer.transform(test_df)
    else:
        density_computer = None

    # Static columns: base (filtered to available) + optional density
    static_cols = [c for c in STATIC_FEATURES if c in train_df.columns]
    if use_density:
        static_cols = static_cols + [
            f for f in DENSITY_FEATURES if f in train_df.columns
        ]

    # Determine risk label per event for stratification
    # ("last" risk = final CDM after the time_to_tca sort inside the dataset)
    has_source = "source" in train_df.columns
    agg_dict = {"risk": ("risk", "last")}
    if has_source:
        agg_dict["source"] = ("source", "first")
    event_meta = train_df.groupby("event_id").agg(**agg_dict).reset_index()
    event_meta["label"] = (event_meta["risk"] > -5).astype(int)

    # Split validation from KELVINS-ONLY events for fair model selection.
    # Space-Track events (sparse features, all high-risk) inflate val metrics.
    if has_source:
        kelvins_events = event_meta[event_meta["source"] == "kelvins"]
        other_events = event_meta[event_meta["source"] != "kelvins"]

        kelvins_ids = kelvins_events["event_id"].values
        kelvins_labels = kelvins_events["label"].values

        # Stratified split on Kelvins events only
        k_train_ids, val_ids = train_test_split(
            kelvins_ids, test_size=val_fraction, stratify=kelvins_labels, random_state=42
        )
        # Training = Kelvins train split + all Space-Track events
        train_ids = np.concatenate([k_train_ids, other_events["event_id"].values])
    else:
        event_ids = event_meta["event_id"].values
        labels = event_meta["label"].values
        train_ids, val_ids = train_test_split(
            event_ids, test_size=val_fraction, stratify=labels, random_state=42
        )

    # Further split validation into val + calibration for conformal prediction
    # (only when the validation split is large enough to split meaningfully)
    cal_ids = np.array([])
    if cal_fraction > 0 and len(val_ids) > 20:
        val_labels = event_meta[event_meta["event_id"].isin(val_ids)]["label"].values
        val_ids_arr = val_ids
        val_ids, cal_ids = train_test_split(
            val_ids_arr,
            test_size=cal_fraction,
            stratify=val_labels,
            random_state=123,  # different seed from train/val split
        )

    train_sub = train_df[train_df["event_id"].isin(train_ids)]
    val_sub = train_df[train_df["event_id"].isin(val_ids)]

    print(f"Building datasets:")
    print(f"  Train events: {len(train_ids)}")
    if has_source:
        n_k = train_sub[train_sub["source"] == "kelvins"]["event_id"].nunique()
        n_s = train_sub[train_sub["source"] != "kelvins"]["event_id"].nunique()
        print(f"    (Kelvins: {n_k}, Space-Track: {n_s})")
    if use_density:
        print(f"  Static features: {len(static_cols)} (base: {len(STATIC_FEATURES)}, "
              f"density: {len(static_cols) - len(STATIC_FEATURES)})")

    train_ds = CDMSequenceDataset(train_sub, static_cols=static_cols)

    print(f"  Val events: {len(val_ids)} (Kelvins-only)")
    val_ds = CDMSequenceDataset(val_sub, static_cols=static_cols)
    val_ds.set_normalization(train_ds)  # use training stats

    print(f"  Test events: {test_df['event_id'].nunique()}")
    # temporal_cols pinned to the training set's so feature order matches
    test_ds = CDMSequenceDataset(test_df, temporal_cols=train_ds.temporal_cols, static_cols=static_cols)
    test_ds.set_normalization(train_ds)

    # Store density computer on train_ds for checkpoint saving
    if density_computer is not None:
        train_ds._density_computer = density_computer

    if cal_fraction > 0 and len(cal_ids) > 0:
        cal_sub = train_df[train_df["event_id"].isin(cal_ids)]
        print(f"  Cal events: {len(cal_ids)} (for conformal prediction)")
        cal_ds = CDMSequenceDataset(cal_sub, static_cols=static_cols)
        cal_ds.set_normalization(train_ds)
        return train_ds, val_ds, cal_ds, test_ds

    return train_ds, val_ds, test_ds
src/data/spacetrack_crossref.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cross-reference detected maneuvers with Space-Track.org CDM data.
2
+
3
+ Queries the CDM_PUBLIC class for recent conjunction data messages
4
+ involving maneuvered satellites. CDM confirmation is the strongest
5
+ signal that a maneuver was collision-avoidance.
6
+
7
+ Requires SPACETRACK_USER and SPACETRACK_PASS environment variables.
8
+ Fails silently if credentials are not set (purely enrichment).
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import time
14
+ import requests
15
+ from pathlib import Path
16
+ from datetime import datetime, timedelta, timezone
17
+
18
# Rate limiting: max 30 requests/min to Space-Track (client-side throttle
# applied between query batches in check_cdm_for_norad_ids)
MAX_REQUESTS_PER_MIN = 30
BATCH_SIZE = 100  # Max NORAD IDs per query URL
CACHE_EXPIRY_DAYS = 7  # cached CDM entries older than this are refetched

# Space-Track REST endpoints
SPACETRACK_BASE = "https://www.space-track.org"
LOGIN_URL = f"{SPACETRACK_BASE}/ajaxauth/login"
CDM_QUERY_URL = f"{SPACETRACK_BASE}/basicspacedata/query/class/cdm_public"
26
+
27
+
28
+ def _get_credentials() -> tuple[str, str]:
29
+ """Get Space-Track credentials from environment."""
30
+ user = os.environ.get("SPACETRACK_USER", "")
31
+ passwd = os.environ.get("SPACETRACK_PASS", "")
32
+ return user, passwd
33
+
34
+
35
+ def _load_cache(cache_path: Path) -> dict:
36
+ """Load CDM cache, filtering expired entries."""
37
+ if not cache_path.exists():
38
+ return {}
39
+
40
+ try:
41
+ with open(cache_path) as f:
42
+ cache = json.load(f)
43
+ except (json.JSONDecodeError, IOError):
44
+ return {}
45
+
46
+ # Filter expired entries
47
+ cutoff = (datetime.now(timezone.utc) - timedelta(days=CACHE_EXPIRY_DAYS)).isoformat()
48
+ return {
49
+ k: v for k, v in cache.items()
50
+ if v.get("cached_at", "") > cutoff
51
+ }
52
+
53
+
54
+ def _save_cache(cache: dict, cache_path: Path):
55
+ """Save CDM cache to disk."""
56
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
57
+ with open(cache_path, "w") as f:
58
+ json.dump(cache, f, indent=2)
59
+
60
+
61
def check_cdm_for_norad_ids(
    norad_ids: list[int],
    lookback_days: int = 7,
    min_pc: float = 1e-7,
    cache_dir: Path | None = None,
) -> dict[int, list[dict]]:
    """Query Space-Track CDM_PUBLIC for recent CDMs involving given satellites.

    Results are cached on disk (see _load_cache/_save_cache); only IDs not in
    the fresh cache trigger network requests, issued in batches of BATCH_SIZE
    with client-side rate limiting between batches.

    Args:
        norad_ids: NORAD catalog IDs to check.
        lookback_days: How far back to search for CDMs.
        min_pc: Minimum probability of collision to include.
        cache_dir: Directory for CDM cache file. Defaults to data/prediction_logs/.

    Returns:
        Map of norad_id -> list of CDM records with PC, TCA, MISS_DISTANCE.
        Empty dict if credentials not set; partial results if login or a
        batch query fails (best-effort enrichment, never raises).
    """
    user, passwd = _get_credentials()
    if not user or not passwd:
        # No credentials configured — this feature is purely enrichment
        return {}

    if cache_dir is None:
        # repo_root/data/prediction_logs relative to this file (src/data/...)
        cache_dir = Path(__file__).parent.parent.parent / "data" / "prediction_logs"

    cache_path = cache_dir / "cdm_cache.json"
    cache = _load_cache(cache_path)

    # Check which IDs need fresh queries; serve the rest from cache
    results = {}
    uncached_ids = []

    for nid in norad_ids:
        key = str(nid)
        if key in cache:
            results[nid] = cache[key].get("cdms", [])
        else:
            uncached_ids.append(nid)

    if not uncached_ids:
        return results

    # Authenticate with Space-Track (session cookie carries auth afterwards)
    try:
        session = requests.Session()
        resp = session.post(LOGIN_URL, data={
            "identity": user,
            "password": passwd,
        }, timeout=30)
        resp.raise_for_status()
    except Exception as e:
        print(f"  Space-Track login failed: {e}")
        return results  # return whatever the cache provided

    # Query in batches
    # NOTE(review): now_str is computed but never used — dead code.
    now_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    lookback_str = (datetime.now(timezone.utc) - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    for batch_start in range(0, len(uncached_ids), BATCH_SIZE):
        batch = uncached_ids[batch_start:batch_start + BATCH_SIZE]
        ids_str = ",".join(str(nid) for nid in batch)

        # Space-Track REST query: comma-separated IDs, TCA newer than cutoff
        query_url = (
            f"{CDM_QUERY_URL}"
            f"/SAT1_NORAD_CAT_ID/{ids_str}"
            f"/TCA/>{lookback_str}"
            f"/orderby/TCA desc"
            f"/format/json"
        )

        try:
            resp = session.get(query_url, timeout=60)
            resp.raise_for_status()
            cdm_records = resp.json()
        except Exception as e:
            print(f"  Space-Track CDM query failed: {e}")
            # Cache empty results for failed IDs to avoid re-querying
            for nid in batch:
                cache[str(nid)] = {
                    "cdms": [],
                    "cached_at": datetime.now(timezone.utc).isoformat(),
                }
            continue

        # Process CDM records, grouping by SAT1 NORAD ID
        batch_results: dict[int, list[dict]] = {nid: [] for nid in batch}

        for cdm in cdm_records:
            try:
                # `or 0` guards against None/"" values in the API payload
                pc = float(cdm.get("PC", 0) or 0)
                if pc < min_pc:
                    continue

                sat1_id = int(cdm.get("SAT1_NORAD_CAT_ID", 0))
                record = {
                    "tca": cdm.get("TCA", ""),
                    "pc": pc,
                    # API reports meters; convert to km
                    "miss_distance_km": float(cdm.get("MISS_DISTANCE", 0) or 0) / 1000.0,
                    "sat1_name": cdm.get("SAT1_NAME", ""),
                    "sat2_name": cdm.get("SAT2_NAME", ""),
                    "sat2_norad": int(cdm.get("SAT2_NORAD_CAT_ID", 0) or 0),
                }

                if sat1_id in batch_results:
                    batch_results[sat1_id].append(record)
            except (ValueError, TypeError):
                continue  # skip malformed records

        # Update cache and results (empty list also cached = "no CDMs found")
        for nid in batch:
            cdms = batch_results.get(nid, [])
            results[nid] = cdms
            cache[str(nid)] = {
                "cdms": cdms,
                "cached_at": datetime.now(timezone.utc).isoformat(),
            }

        # Rate limiting between batches
        if batch_start + BATCH_SIZE < len(uncached_ids):
            time.sleep(60.0 / MAX_REQUESTS_PER_MIN)

    # Save updated cache
    _save_cache(cache, cache_path)

    return results
src/evaluation/__init__.py ADDED
File without changes
src/evaluation/conformal.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code — 2026-02-13
2
+ """Conformal prediction for calibrated risk bounds.
3
+
4
+ Provides distribution-free prediction sets with guaranteed marginal coverage:
5
+ P(true_label ∈ prediction_set) ≥ 1 - alpha
6
+
7
+ This directly addresses NASA CARA's criticism about uncertainty quantification
8
+ in ML-based collision risk assessment. Instead of a single probability, we
9
+ output a prediction set (e.g., {LOW, MODERATE}) that provably covers the
10
+ true risk tier at the specified confidence level.
11
+
12
+ Method: Split conformal prediction (Vovk et al. 2005, Lei et al. 2018)
13
+ - Calibrate on a held-out set separate from training AND model selection
14
+ - Compute nonconformity scores
15
+ - Use quantile of calibration scores to construct prediction sets at test time
16
+
17
+ References:
18
+ - Vovk, Gammerman, Shafer (2005) "Algorithmic Learning in a Random World"
19
+ - Lei et al. (2018) "Distribution-Free Predictive Inference for Regression"
20
+ - Angelopoulos & Bates (2021) "A Gentle Introduction to Conformal Prediction"
21
+ """
22
+
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+
26
+
27
@dataclass
class ConformalResult:
    """Conformal output for one example: a risk-tier set plus a probability interval."""
    # Risk tiers whose range overlaps [lower_bound, upper_bound], e.g. ["LOW", "MODERATE"]
    prediction_set: list[str]
    # |prediction_set| — smaller sets are more informative
    set_size: int
    # Raw model probability the interval is centered on
    risk_prob: float
    # Lower edge of the calibrated probability interval (clamped to >= 0)
    lower_bound: float
    # Upper edge of the calibrated probability interval (clamped to <= 1)
    upper_bound: float
35
+
36
+
37
+ class ConformalPredictor:
38
+ """Split conformal prediction for binary risk classification.
39
+
40
+ Workflow:
41
+ 1. Train model on training set
42
+ 2. Select model (early stopping) on validation set
43
+ 3. calibrate() on a SEPARATE calibration set (held out from validation)
44
+ 4. predict() on test data with coverage guarantee
45
+
46
+ The calibration set must NOT be used for training or model selection,
47
+ otherwise the coverage guarantee is invalidated.
48
+ """
49
+
50
+ # Risk tiers with thresholds
51
+ TIERS = {
52
+ "LOW": (0.0, 0.10),
53
+ "MODERATE": (0.10, 0.40),
54
+ "HIGH": (0.40, 0.70),
55
+ "CRITICAL": (0.70, 1.0),
56
+ }
57
+
58
+ def __init__(self):
59
+ self.quantile_lower = None # q_hat for lower bound
60
+ self.quantile_upper = None # q_hat for upper bound
61
+ self.alpha = None
62
+ self.n_cal = 0
63
+ self.is_calibrated = False
64
+
65
+ def calibrate(
66
+ self,
67
+ cal_probs: np.ndarray,
68
+ cal_labels: np.ndarray,
69
+ alpha: float = 0.10,
70
+ ) -> dict:
71
+ """Calibrate conformal predictor on held-out calibration set.
72
+
73
+ Args:
74
+ cal_probs: Model predicted probabilities on calibration set, shape (n,)
75
+ cal_labels: True binary labels on calibration set, shape (n,)
76
+ alpha: Desired miscoverage rate. 1-alpha = coverage level.
77
+ alpha=0.10 → 90% coverage guarantee.
78
+
79
+ Returns:
80
+ Calibration summary dict with quantiles and statistics
81
+ """
82
+ n = len(cal_probs)
83
+ if n < 10:
84
+ raise ValueError(f"Calibration set too small: {n} examples (need >= 10)")
85
+
86
+ self.alpha = alpha
87
+ self.n_cal = n
88
+
89
+ # Nonconformity score: how "wrong" is the model on each calibration example?
90
+ # For binary classification with probabilities:
91
+ # score = 1 - P(true class)
92
+ # High score = model is wrong/uncertain
93
+ scores = np.where(
94
+ cal_labels == 1,
95
+ 1.0 - cal_probs, # positive: score = 1 - P(positive)
96
+ cal_probs, # negative: score = P(positive) = 1 - P(negative)
97
+ )
98
+
99
+ # Conformal quantile: includes finite-sample correction
100
+ # q_hat = ceil((n+1)(1-alpha))/n -th quantile of scores
101
+ adjusted_level = np.ceil((n + 1) * (1 - alpha)) / n
102
+ adjusted_level = min(adjusted_level, 1.0)
103
+ self.q_hat = float(np.quantile(scores, adjusted_level))
104
+
105
+ # For prediction intervals on the probability itself:
106
+ # We also compute quantiles for constructing upper/lower prob bounds
107
+ # Using calibration residuals: |P(positive) - is_positive|
108
+ residuals = np.abs(cal_probs - cal_labels.astype(float))
109
+ self.q_residual = float(np.quantile(residuals, adjusted_level))
110
+
111
+ self.is_calibrated = True
112
+
113
+ # Report calibration statistics
114
+ empirical_coverage = np.mean(scores <= self.q_hat)
115
+
116
+ summary = {
117
+ "alpha": alpha,
118
+ "target_coverage": 1 - alpha,
119
+ "n_calibration": n,
120
+ "q_hat": self.q_hat,
121
+ "q_residual": self.q_residual,
122
+ "empirical_coverage_cal": float(empirical_coverage),
123
+ "mean_score": float(scores.mean()),
124
+ "median_score": float(np.median(scores)),
125
+ "cal_pos_rate": float(cal_labels.mean()),
126
+ }
127
+
128
+ print(f" Conformal calibration (alpha={alpha}):")
129
+ print(f" Calibration set: {n} examples ({cal_labels.sum():.0f} positive)")
130
+ print(f" q_hat (nonconformity): {self.q_hat:.4f}")
131
+ print(f" q_residual: {self.q_residual:.4f}")
132
+ print(f" Empirical coverage (cal): {empirical_coverage:.4f}")
133
+
134
+ return summary
135
+
136
+ def predict(self, test_probs: np.ndarray) -> list[ConformalResult]:
137
+ """Produce conformal prediction sets for test examples.
138
+
139
+ For each test example, returns:
140
+ - Prediction set: set of risk tiers that could contain the true risk
141
+ - Probability bounds: [lower, upper] interval on the true probability
142
+
143
+ Coverage guarantee: P(true_tier ∈ prediction_set) ≥ 1 - alpha
144
+ """
145
+ if not self.is_calibrated:
146
+ raise RuntimeError("Must call calibrate() before predict()")
147
+
148
+ results = []
149
+ for p in test_probs:
150
+ # Probability bounds from residual quantile
151
+ lower = max(0.0, p - self.q_residual)
152
+ upper = min(1.0, p + self.q_residual)
153
+
154
+ # Prediction set: all tiers that overlap with [lower, upper]
155
+ pred_set = []
156
+ for tier_name, (tier_lo, tier_hi) in self.TIERS.items():
157
+ if lower < tier_hi and upper > tier_lo:
158
+ pred_set.append(tier_name)
159
+
160
+ results.append(ConformalResult(
161
+ prediction_set=pred_set,
162
+ set_size=len(pred_set),
163
+ risk_prob=float(p),
164
+ lower_bound=lower,
165
+ upper_bound=upper,
166
+ ))
167
+
168
+ return results
169
+
170
    def evaluate(
        self,
        test_probs: np.ndarray,
        test_labels: np.ndarray,
    ) -> dict:
        """Evaluate conformal prediction on test set.

        Reports:
        - Marginal coverage: fraction of test examples where true label
          falls within prediction set
        - Average set size: how informative are the predictions
        - Coverage by tier: per-tier coverage (conditional coverage)
        - Efficiency: 1 - (avg_set_size / n_tiers)

        Args:
            test_probs: predicted risk probabilities, one per test example
            test_labels: binary ground-truth labels (0 or 1)

        Returns:
            dict of summary metrics; also prints a console report.

        Raises:
            RuntimeError: if calibrate() has not been called first.
        """
        if not self.is_calibrated:
            raise RuntimeError("Must call calibrate() before evaluate()")

        results = self.predict(test_probs)

        # Map labels to tiers for coverage check.
        # TIERS maps tier name -> (lo, hi) probability range; a probability of
        # exactly 1.0 falls through every half-open range, hence the fallback.
        def label_to_tier(prob: float) -> str:
            for tier_name, (lo, hi) in self.TIERS.items():
                if lo <= prob < hi:
                    return tier_name
            return "CRITICAL"  # prob == 1.0

        # True "tier" based on actual probability (binary: 0 or 1)
        true_tiers = [label_to_tier(float(l)) for l in test_labels]

        # Marginal coverage: does the prediction set contain the true tier?
        covered = [
            true_tier in result.prediction_set
            for true_tier, result in zip(true_tiers, results)
        ]
        marginal_coverage = np.mean(covered)

        # Average set size
        set_sizes = [r.set_size for r in results]
        avg_set_size = np.mean(set_sizes)

        # Coverage conditioned on the true label value (conditional coverage)
        pos_mask = test_labels == 1
        neg_mask = test_labels == 0
        pos_coverage = np.mean([c for c, m in zip(covered, pos_mask) if m]) if pos_mask.sum() > 0 else 0.0
        neg_coverage = np.mean([c for c, m in zip(covered, neg_mask) if m]) if neg_mask.sum() > 0 else 0.0

        # Set size distribution (histogram of how many tiers each set contains)
        size_counts = {}
        for s in set_sizes:
            size_counts[s] = size_counts.get(s, 0) + 1

        # Efficiency: lower set sizes = more informative
        efficiency = 1.0 - (avg_set_size / len(self.TIERS))

        # Interval width statistics
        widths = [r.upper_bound - r.lower_bound for r in results]

        metrics = {
            "alpha": self.alpha,
            "target_coverage": 1 - self.alpha,
            "marginal_coverage": float(marginal_coverage),
            # 1% slack below the nominal level before flagging a violation
            "coverage_guarantee_met": bool(marginal_coverage >= (1 - self.alpha - 0.01)),
            "avg_set_size": float(avg_set_size),
            "efficiency": float(efficiency),
            "positive_coverage": float(pos_coverage),
            "negative_coverage": float(neg_coverage),
            # keys stringified so the dict is JSON-serializable
            "set_size_distribution": {str(k): v for k, v in sorted(size_counts.items())},
            "n_test": len(test_labels),
            "mean_interval_width": float(np.mean(widths)),
            "median_interval_width": float(np.median(widths)),
        }

        print(f"\n Conformal Prediction Evaluation (alpha={self.alpha}):")
        print(f" Target coverage: {1 - self.alpha:.1%}")
        print(f" Marginal coverage: {marginal_coverage:.1%} "
              f"{'OK' if metrics['coverage_guarantee_met'] else 'VIOLATION'}")
        print(f" Positive coverage: {pos_coverage:.1%}")
        print(f" Negative coverage: {neg_coverage:.1%}")
        print(f" Avg set size: {avg_set_size:.2f} / {len(self.TIERS)} tiers")
        print(f" Efficiency: {efficiency:.1%}")
        print(f" Mean interval: [{np.mean([r.lower_bound for r in results]):.3f}, "
              f"{np.mean([r.upper_bound for r in results]):.3f}]")
        print(f" Set size dist: {size_counts}")

        return metrics
255
+
256
+ def save_state(self) -> dict:
257
+ """Serialize calibration state for checkpoint saving."""
258
+ if not self.is_calibrated:
259
+ return {"is_calibrated": False}
260
+ return {
261
+ "is_calibrated": True,
262
+ "alpha": self.alpha,
263
+ "q_hat": self.q_hat,
264
+ "q_residual": self.q_residual,
265
+ "n_cal": self.n_cal,
266
+ "tiers": {k: list(v) for k, v in self.TIERS.items()},
267
+ }
268
+
269
+ @classmethod
270
+ def from_state(cls, state: dict) -> "ConformalPredictor":
271
+ """Restore from serialized state."""
272
+ obj = cls()
273
+ if state.get("is_calibrated", False):
274
+ obj.alpha = state["alpha"]
275
+ obj.q_hat = state["q_hat"]
276
+ obj.q_residual = state["q_residual"]
277
+ obj.n_cal = state["n_cal"]
278
+ obj.is_calibrated = True
279
+ return obj
280
+
281
+
282
def run_conformal_at_multiple_levels(
    cal_probs: np.ndarray,
    cal_labels: np.ndarray,
    test_probs: np.ndarray,
    test_labels: np.ndarray,
    alphas: list[float] | None = None,
) -> dict:
    """Run conformal prediction at multiple coverage levels.

    Useful for reporting: "at 90% coverage, avg set size = X;
    at 95%, avg set size = Y; at 99%, avg set size = Z"

    Args:
        cal_probs: calibration-set predicted probabilities
        cal_labels: calibration-set binary labels
        test_probs: test-set predicted probabilities
        test_labels: test-set binary labels
        alphas: miscoverage levels to evaluate; defaults to
            [0.01, 0.05, 0.10, 0.20] (99%/95%/90%/80% target coverage)

    Returns:
        {"alpha_<a>": {"conformal_metrics": ..., "conformal_state": ...}}
        with one entry per alpha level.
    """
    # NOTE: annotation fixed from `list[float] = None` to `list[float] | None`
    # -- None is the documented default and was previously mistyped.
    if alphas is None:
        alphas = [0.01, 0.05, 0.10, 0.20]

    all_results = {}
    for alpha in alphas:
        # A fresh predictor per level: calibration quantiles depend on alpha.
        cp = ConformalPredictor()
        cp.calibrate(cal_probs, cal_labels, alpha=alpha)
        eval_metrics = cp.evaluate(test_probs, test_labels)
        all_results[f"alpha_{alpha}"] = {
            "conformal_metrics": eval_metrics,
            "conformal_state": cp.save_state(),
        }

    return all_results
src/evaluation/metrics.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Evaluation metrics for conjunction prediction models."""
3
+
4
+ import numpy as np
5
+ from sklearn.metrics import (
6
+ average_precision_score,
7
+ roc_auc_score,
8
+ f1_score,
9
+ precision_recall_curve,
10
+ mean_absolute_error,
11
+ mean_squared_error,
12
+ classification_report,
13
+ )
14
+
15
+
16
def find_optimal_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> tuple[float, float]:
    """Return (threshold, f1) for the F1-maximizing point of the PR curve."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_prob)
    # sklearn returns one more (precision, recall) pair than thresholds;
    # drop the final PR point so arrays align with the threshold array.
    p = precisions[:-1]
    r = recalls[:-1]
    # Epsilon guards against 0/0 when both precision and recall are zero.
    f1_scores = (2 * p * r) / (p + r + 1e-8)
    best = np.argmax(f1_scores)
    return float(thresholds[best]), float(f1_scores[best])
24
+
25
+
26
def evaluate_risk(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """
    Evaluate risk classification predictions.

    Args:
        y_true: binary ground truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for f1_at_50)

    Returns: dict of metrics including optimal threshold F1
    """
    has_positives = y_true.sum() > 0
    y_pred_fixed = (y_prob >= threshold).astype(int)

    results = {
        # AUC metrics are undefined for single-class inputs; report 0.0 then.
        "auc_pr": float(average_precision_score(y_true, y_prob)) if has_positives else 0.0,
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": int(y_true.sum()),
        "n_total": int(len(y_true)),
        "pos_rate": float(y_true.mean()),
    }

    # F1 at the best threshold on the PR curve; fall back to the fixed
    # threshold when there are no positives to sweep against.
    if has_positives:
        opt_threshold, opt_f1 = find_optimal_threshold(y_true, y_prob)
    else:
        opt_threshold, opt_f1 = threshold, results["f1_at_50"]
    results["f1"] = opt_f1
    results["optimal_threshold"] = opt_threshold
    results["threshold"] = opt_threshold

    # Best achievable recall at several fixed precision floors.
    if has_positives:
        precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
        for target_precision in [0.3, 0.5, 0.7]:
            attainable = precisions >= target_precision
            results[f"recall_at_prec_{int(target_precision*100)}"] = (
                float(recalls[attainable].max()) if attainable.any() else 0.0
            )

    return results
71
+
72
+
73
def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """
    Evaluate miss distance regression (log-scale).

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns: dict of metrics
    """
    # Log-scale errors (numpy forms of MAE / RMSE).
    err_log = y_true_log - y_pred_log
    mae_log = float(np.mean(np.abs(err_log)))
    rmse_log = float(np.sqrt(np.mean(err_log ** 2)))

    # Convert back to km for interpretable metrics
    y_true_km = np.expm1(y_true_log)
    y_pred_km = np.expm1(y_pred_log)
    err_km = np.abs(y_true_km - y_pred_km)

    return {
        "mae_log": mae_log,
        "rmse_log": rmse_log,
        "mae_km": float(np.mean(err_km)),
        "median_abs_error_km": float(np.median(err_km)),
    }
97
+
98
+
99
def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run full evaluation suite for a model.

    Args:
        model_name: label used in the printed report and the "model" key
        y_risk_true: binary risk labels
        y_risk_prob: predicted risk probabilities
        y_miss_true_log: log1p(miss_distance_km) ground truth
        y_miss_pred_log: log1p(miss_distance_km) predictions

    Returns:
        {"model": model_name} merged with the risk-classification and
        miss-distance metric dicts (flat, for easy tabulation).
    """
    risk_metrics = evaluate_risk(y_risk_true, y_risk_prob)
    miss_metrics = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)

    # Flatten both metric dicts into a single record.
    results = {"model": model_name, **risk_metrics, **miss_metrics}

    # Human-readable console report.
    print(f"\n{'='*60}")
    print(f" {model_name}")
    print(f"{'='*60}")
    print(f" Risk Classification:")
    print(f" AUC-PR: {risk_metrics['auc_pr']:.4f}")
    print(f" AUC-ROC: {risk_metrics['auc_roc']:.4f}")
    print(f" F1 (opt): {risk_metrics['f1']:.4f} (threshold={risk_metrics.get('optimal_threshold', 0.5):.3f})")
    print(f" F1 (0.50): {risk_metrics['f1_at_50']:.4f}")
    print(f" Positives: {risk_metrics['n_positive']}/{risk_metrics['n_total']} "
          f"({risk_metrics['pos_rate']:.1%})")
    print(f" Miss Distance:")
    print(f" MAE (log): {miss_metrics['mae_log']:.4f}")
    print(f" MAE (km): {miss_metrics['mae_km']:.2f}")
    print(f" Median AE: {miss_metrics['median_abs_error_km']:.2f} km")
    print(f"{'='*60}")

    return results
src/evaluation/staleness.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-13
2
+ """TLE Staleness Sensitivity Experiment.
3
+
4
+ Evaluates how model performance degrades as CDM data becomes stale.
5
+ Simulates staleness by filtering CDM sequences to only include updates
6
+ received at least `cutoff_days` before TCA.
7
+
8
+ The Kelvins test set has time_to_tca in [2.0, 7.0] days, so meaningful
9
+ cutoffs are in that range. A cutoff of 2.0 keeps all data (baseline),
10
+ while a cutoff of 6.0 keeps only the earliest CDMs.
11
+
12
+ Ground-truth labels always come from the ORIGINAL (untruncated) test set —
13
+ we're measuring how well models predict with less-recent information.
14
+ """
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+ from torch.utils.data import DataLoader
20
+
21
+ from src.data.cdm_loader import build_events, events_to_flat_features, get_feature_columns
22
+ from src.data.sequence_builder import CDMSequenceDataset
23
+ from src.evaluation.metrics import evaluate_risk
24
+
25
+ # Staleness cutoffs (days before TCA)
26
+ # 2.0 = keep all data (baseline), 6.0 = only very early CDMs
27
+ DEFAULT_CUTOFFS = [2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0]
28
+ QUICK_CUTOFFS = [2.0, 4.0, 6.0]
29
+
30
+
31
def truncate_cdm_dataframe(df: pd.DataFrame, cutoff_days: float) -> pd.DataFrame:
    """Filter CDM rows to only those with time_to_tca >= cutoff_days.

    Simulates data staleness: if cutoff=4.0, the model only sees CDMs
    that arrived 4+ days before closest approach. The input frame is not
    modified; a copy of the surviving rows is returned.
    """
    keep = df["time_to_tca"] >= cutoff_days
    return df.loc[keep].copy()
38
+
39
+
40
def get_ground_truth_labels(df: pd.DataFrame) -> dict:
    """Extract per-event ground truth labels from the FULL (untruncated) dataset.

    Labels come from the final CDM per event (the row with the smallest
    time_to_tca, i.e. closest to TCA).
    Returns: {event_id: {"risk_label": int, "miss_log": float, "altitude_km": float}}
    """
    labels: dict[int, dict] = {}
    for event_id, group in df.groupby("event_id"):
        # Smallest time_to_tca == last CDM received before closest approach.
        final = group.loc[group["time_to_tca"].idxmin()]
        # risk is log10 collision probability; > -5 counts as high risk.
        miss = max(final.get("miss_distance", 0.0), 0.0)
        labels[int(event_id)] = {
            "risk_label": int(final["risk"] > -5),
            "miss_log": float(np.log1p(miss)),
            "altitude_km": float(final.get("t_h_apo", 0.0)),
        }
    return labels
59
+
60
+
61
def evaluate_baseline_at_cutoff(baseline_model, ground_truth: dict, cutoff: float) -> dict:
    """Evaluate baseline model. Uses altitude only, unaffected by staleness."""
    records = list(ground_truth.values())
    altitudes = np.array([rec["altitude_km"] for rec in records])
    y_true = np.array([rec["risk_label"] for rec in records])
    # Baseline returns (risk_probs, miss_preds); only risk is scored here.
    risk_probs, _ = baseline_model.predict(altitudes)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics
70
+
71
+
72
def evaluate_xgboost_at_cutoff(
    xgboost_model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    feature_cols: list[str],
    cutoff: float,
) -> dict:
    """Evaluate XGBoost on truncated (staleness-simulated) CDM data.

    Args:
        xgboost_model: fitted XGBoostConjunctionModel (scaler + risk head used)
        truncated_df: CDM rows surviving the staleness cutoff
        ground_truth: {event_id: {"risk_label": ...}} from the FULL test set
        feature_cols: feature column names passed to build_events
        cutoff: staleness cutoff in days, recorded in the returned metrics

    Returns:
        evaluate_risk() metrics plus "cutoff" and "n_events"; zeroed metrics
        when no events survive or no positive labels remain.
    """
    events = build_events(truncated_df, feature_cols)
    if len(events) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    X, _, _ = events_to_flat_features(events)

    # Keep only events that have ground-truth labels.
    event_ids = [e.event_id for e in events]
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    X = X[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    # Align feature count with what the scaler was fitted on (e.g. the model
    # was trained on augmented data with extra columns): zero-pad missing
    # columns, drop extras. BUGFIX: this alignment was previously performed
    # twice (an hstack pad before masking AND np.pad after); the first pass
    # was redundant and lacked the truncation branch -- one pass suffices.
    expected = xgboost_model.scaler.n_features_in_
    if X.shape[1] < expected:
        X = np.pad(X, ((0, 0), (0, expected - X.shape[1])), constant_values=0)
    elif X.shape[1] > expected:
        X = X[:, :expected]

    risk_probs = xgboost_model.predict_risk(X)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics
114
+
115
+
116
def evaluate_pitft_at_cutoff(
    model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    train_ds: CDMSequenceDataset,
    device: torch.device,
    temperature: float = 1.0,
    cutoff: float = 0.0,
    batch_size: int = 128,
) -> dict:
    """Evaluate PI-TFT on truncated CDM data with temperature scaling.

    Args:
        model: trained PhysicsInformedTFT; put into eval mode here
        truncated_df: CDM rows surviving the staleness cutoff
        ground_truth: {event_id: {"risk_label": ...}} from the FULL test set
        train_ds: training dataset -- supplies the column lists and the
            normalization statistics applied to the test sequences
        device: torch device for inference
        temperature: divisor on the risk logit before sigmoid
            (temperature-scaling calibration; 1.0 = uncalibrated)
        cutoff: staleness cutoff in days, recorded in the returned metrics
        batch_size: inference batch size

    Returns:
        evaluate_risk() metrics plus "cutoff" and "n_events"; zeroed metrics
        when no sequences survive or no positive labels remain.
    """
    # Ensure all required columns exist (pad missing with 0)
    df = truncated_df.copy()
    for col in train_ds.temporal_cols + train_ds.static_cols:
        if col not in df.columns:
            df[col] = 0.0

    # Rebuild sequences with the training column layout, then reuse the
    # training normalization so inputs match what the model saw in training.
    test_ds = CDMSequenceDataset(
        df,
        temporal_cols=train_ds.temporal_cols,
        static_cols=train_ds.static_cols,
    )
    test_ds.set_normalization(train_ds)

    if len(test_ds) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    # Get event IDs from the dataset (same order as the non-shuffled loader)
    event_ids = [e["event_id"] for e in test_ds.events]

    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model.eval()
    all_probs = []

    with torch.no_grad():
        for batch in loader:
            temporal = batch["temporal"].to(device)
            static = batch["static"].to(device)
            tca = batch["time_to_tca"].to(device)
            mask = batch["mask"].to(device)

            # Model returns (risk_logit, ...); only the risk head is scored.
            risk_logit, _, _, _ = model(temporal, static, tca, mask)
            # Temperature-scaled sigmoid -> calibrated probability.
            probs = torch.sigmoid(risk_logit / temperature).cpu().numpy().flatten()
            all_probs.append(probs)

    risk_probs = np.concatenate(all_probs)

    # Match predictions to ground truth (drop events without labels)
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    risk_probs = risk_probs[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = int(len(y_true))
    return metrics
177
+
178
+
179
def run_staleness_experiment(
    baseline_model,
    xgboost_model,
    pitft_model,
    pitft_checkpoint: dict,
    test_df: pd.DataFrame,
    train_ds: CDMSequenceDataset,
    feature_cols: list[str],
    device: torch.device,
    cutoffs: list[float] | None = None,
    quick: bool = False,
) -> dict:
    """Run the full staleness experiment across all cutoffs and models.

    Ground-truth labels are computed once from the FULL test set; each model
    is then evaluated on progressively truncated CDM histories.

    Args:
        baseline_model: OrbitalShellBaseline instance
        xgboost_model: XGBoostConjunctionModel instance
        pitft_model: PhysicsInformedTFT (eval mode), or None to skip
        pitft_checkpoint: checkpoint dict with temperature
        test_df: ORIGINAL (untruncated) test DataFrame
        train_ds: CDMSequenceDataset from training data (for normalization)
        feature_cols: list of feature column names for XGBoost
        device: torch device
        cutoffs: list of staleness cutoffs (days before TCA); defaults to
            QUICK_CUTOFFS or DEFAULT_CUTOFFS depending on `quick`
        quick: if True, use fewer cutoffs

    Returns:
        dict with "baseline"/"xgboost"/"pitft" metric lists (one entry per
        cutoff) plus experiment-level counts.
    """
    if cutoffs is None:
        cutoffs = QUICK_CUTOFFS if quick else DEFAULT_CUTOFFS

    # Labels always come from the untruncated data (see module docstring).
    ground_truth = get_ground_truth_labels(test_df)
    n_pos = sum(1 for gt in ground_truth.values() if gt["risk_label"] == 1)
    print(f"\nGround truth: {len(ground_truth)} events, {n_pos} positive")

    # Temperature-scaling factor saved at training time; 1.0 = uncalibrated.
    temperature = 1.0
    if pitft_checkpoint:
        temperature = pitft_checkpoint.get("temperature", 1.0)

    results = {
        "cutoffs": cutoffs,
        "n_test_events": len(ground_truth),
        "n_positive": n_pos,
        "baseline": [],
        "xgboost": [],
        "pitft": [],
    }

    for cutoff in cutoffs:
        print(f"\n{'='*50}")
        print(f"Staleness cutoff: {cutoff:.1f} days")
        print(f"{'='*50}")

        truncated = truncate_cdm_dataframe(test_df, cutoff)
        n_events = truncated["event_id"].nunique()
        n_rows = len(truncated)
        print(f" Surviving: {n_events} events, {n_rows} CDMs")

        # Baseline (uses altitude only — constant across cutoffs)
        bl = evaluate_baseline_at_cutoff(baseline_model, ground_truth, cutoff)
        results["baseline"].append(bl)
        print(f" Baseline AUC-PR={bl.get('auc_pr', 0):.4f}, F1={bl.get('f1', 0):.4f}")

        # XGBoost — skipped (zeroed entry) when no events survive the cutoff
        if n_events > 0:
            xgb = evaluate_xgboost_at_cutoff(
                xgboost_model, truncated, ground_truth, feature_cols, cutoff
            )
        else:
            xgb = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["xgboost"].append(xgb)
        print(f" XGBoost AUC-PR={xgb.get('auc_pr', 0):.4f}, "
              f"F1={xgb.get('f1', 0):.4f} ({xgb.get('n_events', 0)} events)")

        # PI-TFT — skipped when no events survive or no model was provided
        if n_events > 0 and pitft_model is not None:
            tft = evaluate_pitft_at_cutoff(
                pitft_model, truncated, ground_truth, train_ds,
                device, temperature=temperature, cutoff=cutoff,
            )
        else:
            tft = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["pitft"].append(tft)
        print(f" PI-TFT AUC-PR={tft.get('auc_pr', 0):.4f}, "
              f"F1={tft.get('f1', 0):.4f}")

    return results
src/model/__init__.py ADDED
File without changes
src/model/baseline.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Model 1: Naive Baseline -- Orbital Shell Density Prior.
3
+
4
+ Predicts collision risk based solely on the altitude band of the conjunction,
5
+ using historical base rates. This establishes that altitude alone is predictive
6
+ (LEO is more crowded) but insufficient for actionable conjunction assessment.
7
+ """
8
+
9
+ import json
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from collections import defaultdict
13
+
14
+
15
class OrbitalShellBaseline:
    """
    Altitude-band collision rate baseline.

    For any conjunction event, predict the average risk and miss distance
    for that altitude regime. Bins events into 50km altitude bands.
    """

    def __init__(self, bin_width_km: float = 50.0):
        self.bin_width = bin_width_km
        # bin center (km, int) -> {"mean_risk", "mean_miss_log", "count", "risk_rate"}
        self.bins: dict[int, dict] = {}
        # fallback statistics for altitudes with no fitted bin
        self.global_stats: dict = {}

    def _altitude_to_bin(self, alt_km: float) -> int:
        # Snap to the nearest multiple of bin_width.
        return int(round(alt_km / self.bin_width) * self.bin_width)

    def fit(self, altitudes: np.ndarray, y_risk: np.ndarray, y_miss_log: np.ndarray):
        """
        Fit baseline from altitude array and labels.

        Args:
            altitudes: altitude in km for each event
            y_risk: binary risk labels
            y_miss_log: log1p(miss_distance_km) targets
        """
        # Global fallback stats for unseen altitude bands.
        self.global_stats = {
            "mean_risk": float(np.mean(y_risk)),
            "mean_miss_log": float(np.mean(y_miss_log)),
            "count": int(len(y_risk)),
        }

        # Group the labels by altitude band.
        grouped = defaultdict(lambda: {"risks": [], "misses": []})
        for alt, risk, miss in zip(altitudes, y_risk, y_miss_log):
            entry = grouped[self._altitude_to_bin(alt)]
            entry["risks"].append(risk)
            entry["misses"].append(miss)

        # Per-bin summary statistics.
        self.bins = {
            b: {
                "mean_risk": float(np.mean(d["risks"])),
                "mean_miss_log": float(np.mean(d["misses"])),
                "count": len(d["risks"]),
                "risk_rate": float(np.sum(d["risks"]) / len(d["risks"])),
            }
            for b, d in grouped.items()
        }

        print(f"Baseline fit: {len(self.bins)} altitude bins, "
              f"global risk rate = {self.global_stats['mean_risk']:.4f}")

    def predict(self, altitudes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict risk probability and log miss distance for each altitude.

        Returns: (risk_probs, miss_log_preds)
        """
        risk_out = []
        miss_out = []

        for alt in altitudes:
            stats = self.bins.get(self._altitude_to_bin(alt))
            if stats is None:
                # Unseen band: fall back to the global averages.
                risk_out.append(self.global_stats["mean_risk"])
                miss_out.append(self.global_stats["mean_miss_log"])
            else:
                risk_out.append(stats["risk_rate"])
                miss_out.append(stats["mean_miss_log"])

        return np.array(risk_out), np.array(miss_out)

    def save(self, path: Path):
        """Save model to JSON."""
        payload = {
            "bin_width": self.bin_width,
            # JSON object keys must be strings.
            "bins": {str(k): v for k, v in self.bins.items()},
            "global_stats": self.global_stats,
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(payload, f, indent=2)
        print(f"Baseline saved to {path}")

    @classmethod
    def load(cls, path: Path) -> "OrbitalShellBaseline":
        """Load model from JSON."""
        with open(path) as f:
            payload = json.load(f)
        model = cls(bin_width_km=payload["bin_width"])
        model.bins = {int(k): v for k, v in payload["bins"].items()}
        model.global_stats = payload["global_stats"]
        return model
src/model/classical.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Model 2: Classical ML -- XGBoost on engineered CDM features.
3
+
4
+ Dual-head model:
5
+ - Risk classifier (binary: high-risk vs safe)
6
+ - Miss distance regressor (log-scale km)
7
+ """
8
+
9
+ import pickle
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from xgboost import XGBClassifier, XGBRegressor
13
+ from sklearn.preprocessing import StandardScaler
14
+
15
+
16
class XGBoostConjunctionModel:
    """XGBoost with engineered CDM features.

    Dual-head model sharing one StandardScaler:
      - risk_classifier: binary high-risk vs safe probability
      - miss_regressor: miss distance on the log1p(km) scale
    """

    def __init__(self):
        # Fitted in fit(); both heads consume the same scaled feature matrix.
        self.scaler = StandardScaler()

        self.risk_classifier = XGBClassifier(
            n_estimators=500,
            max_depth=8,
            learning_rate=0.05,
            scale_pos_weight=50,  # severe class imbalance
            eval_metric="aucpr",
            tree_method="hist",
            random_state=42,
        )

        self.miss_regressor = XGBRegressor(
            n_estimators=500,
            max_depth=8,
            learning_rate=0.05,
            objective="reg:squaredlogerror",
            tree_method="hist",
            random_state=42,
        )

    def fit(
        self,
        X_train: np.ndarray,
        y_risk: np.ndarray,
        y_miss_log: np.ndarray,
        X_val: np.ndarray = None,
        y_risk_val: np.ndarray = None,
        y_miss_val: np.ndarray = None,
    ):
        """Train both heads.

        Args:
            X_train: (n_samples, n_features) raw feature matrix
            y_risk: binary risk labels
            y_miss_log: log1p(miss_distance_km) regression targets
            X_val / y_risk_val / y_miss_val: optional validation data used
                only for xgboost's eval_set progress reporting
                (if X_val is given, the matching labels must be given too)
        """
        # Scale features; fit the scaler on training data only
        X_scaled = self.scaler.fit_transform(X_train)

        # Risk classifier
        print(f"Training risk classifier (pos_rate={y_risk.mean():.4f}) ...")
        eval_set = None
        if X_val is not None:
            eval_set = [(self.scaler.transform(X_val), y_risk_val)]
        self.risk_classifier.fit(
            X_scaled, y_risk,
            eval_set=eval_set,
            verbose=50,  # log eval metrics every 50 boosting rounds
        )

        # Miss distance regressor (log-scale, must be > 0 for squaredlogerror)
        y_miss_positive = np.clip(y_miss_log, 1e-6, None)
        print("Training miss distance regressor ...")
        eval_set_miss = None
        if X_val is not None:
            y_miss_val_pos = np.clip(y_miss_val, 1e-6, None)
            eval_set_miss = [(self.scaler.transform(X_val), y_miss_val_pos)]
        self.miss_regressor.fit(
            X_scaled, y_miss_positive,
            eval_set=eval_set_miss,
            verbose=50,
        )

    def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict risk probability and miss distance.

        Returns: (risk_probs, miss_distance_km)
        """
        X_scaled = self.scaler.transform(X)
        # Column 1 of predict_proba = probability of the positive (risky) class
        risk_probs = self.risk_classifier.predict_proba(X_scaled)[:, 1]
        miss_log = self.miss_regressor.predict(X_scaled)
        # Invert the log1p target transform back to km
        miss_km = np.expm1(miss_log)
        return risk_probs, miss_km

    def predict_risk(self, X: np.ndarray) -> np.ndarray:
        """Predict risk probability only."""
        X_scaled = self.scaler.transform(X)
        return self.risk_classifier.predict_proba(X_scaled)[:, 1]

    def save(self, path: Path):
        """Save all components (scaler + both heads) as one pickle."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "scaler": self.scaler,
                "risk_classifier": self.risk_classifier,
                "miss_regressor": self.miss_regressor,
            }, f)
        print(f"XGBoost model saved to {path}")

    @classmethod
    def load(cls, path: Path) -> "XGBoostConjunctionModel":
        """Load all components.

        NOTE: pickle.load executes arbitrary code -- only load checkpoints
        from trusted sources.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        model = cls()
        model.scaler = data["scaler"]
        model.risk_classifier = data["risk_classifier"]
        model.miss_regressor = data["miss_regressor"]
        return model
src/model/deep.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-08
2
+ """Model 3: Physics-Informed Temporal Fusion Transformer (PI-TFT).
3
+
4
+ Architecture overview (think of it like reading serial lab values):
5
+
6
+ 1. VARIABLE SELECTION: Not all 22 CDM features matter equally. The model
7
+ learns attention weights over features -- e.g., miss_distance and
8
+ covariance shrinkage rate might matter more than raw orbital elements.
9
+ This is like a doctor learning which labs to focus on.
10
+
11
+ 2. STATIC CONTEXT: Object properties (altitude, size, eccentricity) don't
12
+ change between CDM updates. They're encoded once and injected as context
13
+ into the temporal processing. Like knowing the patient's age and history.
14
+
15
+ 3. CONTINUOUS TIME EMBEDDING: CDMs arrive at irregular intervals (not evenly
16
+ spaced). Instead of positional encoding (position 1, 2, 3...), we embed
17
+ the actual time_to_tca value. The model knows "this CDM was 3.2 days
18
+ before closest approach" vs "this one was 0.5 days before."
19
+
20
+ 4. TEMPORAL SELF-ATTENTION: The Transformer reads the full CDM sequence and
21
+ learns which updates were most informative. A sudden miss distance drop
22
+ at day -2 gets more attention than a stable reading at day -5.
23
+
24
+ 5. PREDICTION HEADS: The final hidden state (from the most recent CDM)
25
+ feeds into two prediction heads:
26
+ - Risk classifier: sigmoid probability of high-risk collision
27
+ - Miss distance regressor: predicted log(miss distance in km)
28
+
29
+ 6. PHYSICS LOSS: The training loss includes a penalty when the model predicts
30
+ a miss distance BELOW the Minimum Orbital Intersection Distance (MOID).
31
+ MOID is the closest the two orbits can geometrically get. Predicting
32
+ closer than MOID is physically impossible (without a maneuver), so we
33
+ penalize it. This is like penalizing a model for predicting negative
34
+ blood pressure -- constraining outputs to the physically possible range.
35
+ """
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+ import math
41
+
42
+
43
class GatedResidualNetwork(nn.Module):
    """
    Gated skip connection with ELU activation and layer norm.

    A "smart residual block": a sigmoid gate learns how much of the
    transformed input to blend with the untouched input. gate=0 passes the
    input straight through; gate=1 applies the full transform.
    """

    def __init__(self, d_model: int, d_hidden: int = None, dropout: float = 0.1):
        super().__init__()
        if d_hidden is None:
            d_hidden = d_model
        # Layer creation order matters for reproducible seeded initialization.
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)
        self.gate_fc = nn.Linear(d_hidden, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.dropout(F.elu(self.fc1(x)))
        candidate = self.fc2(hidden)
        blend = torch.sigmoid(self.gate_fc(hidden))
        # LayerNorm over input + gated transform (the residual path).
        return self.norm(x + blend * candidate)
+
69
+
70
class VariableSelectionNetwork(nn.Module):
    """
    Learns which input features matter most via softmax attention.

    For N input features, produces N attention weights that sum to 1.
    Each feature is independently projected to d_model, then weighted and
    summed. The weights are interpretable -- they expose which CDM columns
    the model found most predictive.
    """

    def __init__(self, n_features: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model

        # One scalar -> d_model projection per input feature.
        self.feature_projections = nn.ModuleList(
            [nn.Linear(1, d_model) for _ in range(n_features)]
        )

        # Maps the concatenated projections to a softmax over features.
        self.gate_network = nn.Sequential(
            nn.Linear(n_features * d_model, n_features),
            nn.Softmax(dim=-1),
        )

        self.grn = GatedResidualNetwork(d_model, dropout=dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (..., n_features) -- (B, F) for static or (B, S, F) for temporal

        Returns:
            output: (..., d_model) -- weighted combination of projected features
            weights: (..., n_features) -- attention weights (sum to 1)
        """
        # Project each scalar feature independently, then stack along a new
        # feature axis: (..., n_features, d_model).
        per_feature = torch.stack(
            [
                proj(x[..., i].unsqueeze(-1))
                for i, proj in enumerate(self.feature_projections)
            ],
            dim=-2,
        )

        # Flatten the last two axes so the gate sees every projection at once.
        flattened = per_feature.reshape(*per_feature.shape[:-2], -1)
        weights = self.gate_network(flattened)  # (..., n_features)

        # Attention-weighted sum over the feature axis, refined by the GRN.
        combined = (per_feature * weights.unsqueeze(-1)).sum(dim=-2)
        combined = self.grn(combined)

        return combined, weights
124
+
125
+
126
class PhysicsInformedTFT(nn.Module):
    """
    Physics-Informed Temporal Fusion Transformer for conjunction assessment.

    Input flow:
        temporal_features (B, S, F_t) → Variable Selection → time embedding → self-attention → attention pool → heads
        static_features (B, F_s) → Variable Selection → context injection ↗

    Output:
        risk_logit: (B, 1) — raw logit for risk classification (apply sigmoid for probability)
        miss_log: (B, 1) — predicted log1p(miss_distance_km)
        pc_log10: (B, 1) — predicted log10(Pc) collision probability
        feature_weights: (B, S, F_t) — which temporal features mattered
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        max_seq_len: int = 30,
    ):
        super().__init__()
        self.d_model = d_model
        # NOTE(review): max_seq_len is stored but not read anywhere in this
        # class — presumably consumed by callers; confirm before removing.
        self.max_seq_len = max_seq_len

        # --- Variable Selection Networks ---
        self.temporal_vsn = VariableSelectionNetwork(n_temporal_features, d_model, dropout)
        self.static_vsn = VariableSelectionNetwork(n_static_features, d_model, dropout)

        # --- Static context encoding ---
        self.static_encoder = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        # Static -> enrichment vector that's added to each temporal step
        self.static_to_enrichment = nn.Linear(d_model, d_model)

        # --- Continuous time embedding ---
        # Instead of fixed positional encoding, we embed the actual time_to_tca
        self.time_embedding = nn.Sequential(
            nn.Linear(1, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model),
        )

        # --- Transformer encoder layers ---
        # norm_first=True (pre-norm) for more stable training of small stacks.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 2,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers
        )

        # --- Pre/post attention processing ---
        self.pre_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)
        self.post_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)

        # --- Attention-weighted pooling ---
        # Learns which time steps matter most instead of just taking the last one.
        # Softmax attention over all real positions, with padding masked out.
        self.pool_attention = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Tanh(),
            nn.Linear(d_model // 2, 1),
        )

        # --- Prediction heads ---
        # Binary risk classification (raw logit).
        self.risk_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

        # Miss-distance regression in log1p(km) space.
        self.miss_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

        # --- Collision probability head ---
        # Predicts log10(Pc) directly instead of binary risk classification.
        # Pc ranges from ~1e-20 to ~1e-1, so log10 scale maps to [-20, -1].
        # The Kelvins `risk` column is already log10(Pc).
        self.pc_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def encode_sequence(
        self,
        temporal_features: torch.Tensor,  # (B, S, F_t)
        static_features: torch.Tensor,    # (B, F_s)
        time_to_tca: torch.Tensor,        # (B, S, 1)
        mask: torch.Tensor,               # (B, S) — True for real, False for padding
    ):
        """Encode CDM sequence into per-timestep hidden states.

        Returns:
            hidden: (B, S, D) per-timestep representations after Transformer
            temporal_weights: (B, S, F_t) variable selection attention weights
        """
        # 1. Variable selection -- learn which features matter
        temporal_selected, temporal_weights = self.temporal_vsn(temporal_features)
        # temporal_selected: (B, S, D), temporal_weights: (B, S, F_t)

        # static_weights is computed but not used downstream here; only the
        # temporal selection weights are surfaced to callers.
        static_selected, static_weights = self.static_vsn(static_features)
        # static_selected: (B, D)

        # 2. Static context -- compute enrichment vector
        static_ctx = self.static_encoder(static_selected)  # (B, D)
        enrichment = self.static_to_enrichment(static_ctx)  # (B, D)

        # 3. Continuous time embedding
        t_embed = self.time_embedding(time_to_tca)  # (B, S, D)

        # 4. Combine: temporal + time + static context (broadcast over S)
        x = temporal_selected + t_embed + enrichment.unsqueeze(1)

        # 5. Pre-attention GRN
        x = self.pre_attn_grn(x)

        # 6. Transformer self-attention
        # Convert mask: True=real -> need to invert for PyTorch's src_key_padding_mask
        # PyTorch expects True=ignore, so we flip
        padding_mask = ~mask  # (B, S), True = pad position to ignore
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        # 7. Post-attention GRN
        x = self.post_attn_grn(x)

        return x, temporal_weights

    def forward(
        self,
        temporal_features: torch.Tensor,  # (B, S, F_t)
        static_features: torch.Tensor,    # (B, F_s)
        time_to_tca: torch.Tensor,        # (B, S, 1)
        mask: torch.Tensor,               # (B, S) — True for real, False for padding
    ):
        """Full forward pass: encode, pool over time, run prediction heads.

        Returns:
            risk_logit: (B, 1), miss_log: (B, 1), pc_log10: (B, 1),
            temporal_weights: (B, S, F_t)
        """
        B, S, _ = temporal_features.shape

        # Steps 1-7: encode sequence into per-timestep hidden states
        x, temporal_weights = self.encode_sequence(
            temporal_features, static_features, time_to_tca, mask
        )

        # 8. Attention-weighted pooling over all real positions
        # Instead of just the last CDM, learn which time steps matter most
        attn_scores = self.pool_attention(x).squeeze(-1)  # (B, S)
        # Mask padding positions with -inf so they get zero attention
        attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)  # (B, S)
        # Handle all-padding edge case (shouldn't happen but be safe):
        # softmax over all -inf yields NaN, which we zero out.
        attn_weights = attn_weights.nan_to_num(0.0)
        x_pooled = (x * attn_weights.unsqueeze(-1)).sum(dim=1)  # (B, D)

        # 9. Prediction heads
        risk_logit = self.risk_head(x_pooled)  # (B, 1)
        miss_log = self.miss_head(x_pooled)  # (B, 1)
        pc_log10 = self.pc_head(x_pooled)  # (B, 1) — log10(Pc)

        return risk_logit, miss_log, pc_log10, temporal_weights

    def count_parameters(self) -> int:
        """Number of trainable parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
309
+
310
+
311
class SigmoidFocalLoss(nn.Module):
    """
    Focal Loss for binary classification (Lin et al., 2017).

    Down-weights well-classified examples so the model focuses on hard cases.
    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    With gamma=0, this reduces to standard weighted BCE.
    With gamma=2, easy examples (p_t > 0.9) get ~100x less weight.

    Args:
        alpha: weight on the positive class; negatives get (1 - alpha).
        gamma: focusing exponent; larger means stronger down-weighting of
            easy examples.
        reduction: "mean" (default), "sum", or "none" (per-element losses).
    """

    def __init__(self, alpha: float = 0.75, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute focal loss from raw logits and {0,1} targets (same shape)."""
        p = torch.sigmoid(logits)
        # p_t = probability of the true class
        p_t = targets * p + (1 - targets) * (1 - p)
        # alpha_t = alpha for positive class, (1-alpha) for negative
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        # focal modulator: (1 - p_t)^gamma
        focal_weight = (1 - p_t) ** self.gamma
        # BCE per-element (numerically stable, computed from logits)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        loss = alpha_t * focal_weight * bce
        if self.reduction == "none":
            return loss
        # FIX: reduction="sum" previously fell through to mean() silently.
        if self.reduction == "sum":
            return loss.sum()
        return loss.mean()
342
+
343
+
344
class PhysicsInformedLoss(nn.Module):
    """
    Combined task loss + physics regularization.

    Total loss = risk_weight * Focal-or-BCE(risk) + miss_weight * MSE(miss_distance)
                 + pc_weight * MSE(log10_Pc) + physics_weight * ReLU(MOID - predicted_miss)

    The physics term: MOID (Minimum Orbital Intersection Distance) is the
    geometric minimum distance between two orbits. The actual miss distance
    at closest approach CANNOT be less than MOID (without a maneuver).
    If the model predicts miss < MOID, we penalize it.

    The Pc term: direct regression on log10(collision probability). The Kelvins
    `risk` column is log10(Pc), giving us 162K labeled examples. This lets
    the model output calibrated collision probabilities, not just binary risk.

    For the Kelvins dataset, we approximate MOID from the orbital elements
    in the CDM features. When MOID isn't available, the physics term is 0.
    """

    def __init__(
        self,
        risk_weight: float = 1.0,
        miss_weight: float = 0.1,
        pc_weight: float = 0.3,
        physics_weight: float = 0.2,
        pos_weight: float = 50.0,
        use_focal: bool = False,
        focal_alpha: float = 0.75,
        focal_gamma: float = 2.0,
    ):
        super().__init__()
        self.risk_weight = risk_weight
        self.miss_weight = miss_weight
        self.pc_weight = pc_weight
        self.physics_weight = physics_weight
        self.use_focal = use_focal
        if use_focal:
            # reduction="none" so per-sample domain weights can be applied in
            # forward(); we take the mean ourselves.
            self.risk_loss = SigmoidFocalLoss(
                alpha=focal_alpha, gamma=focal_gamma, reduction="none"
            )
        else:
            self.risk_loss = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(pos_weight)
            )
        self.miss_loss = nn.MSELoss()

    def forward(
        self,
        risk_logit: torch.Tensor,              # (B, 1)
        miss_pred_log: torch.Tensor,           # (B, 1)
        risk_target: torch.Tensor,             # (B,)
        miss_target_log: torch.Tensor,         # (B,)
        pc_pred_log10: torch.Tensor = None,    # (B, 1) predicted log10(Pc)
        pc_target_log10: torch.Tensor = None,  # (B,) target log10(Pc)
        moid_log: torch.Tensor = None,         # (B,) optional, log1p(MOID_km)
        domain_weight: torch.Tensor = None,    # (B,) per-sample weight
    ) -> tuple[torch.Tensor, dict]:
        """Compute the weighted multi-task loss.

        Returns:
            total: scalar loss tensor (participates in autograd)
            metrics: dict of float components for logging
        """
        logits = risk_logit.squeeze(-1)

        # --- Risk classification loss ---
        # FIX: domain_weight was previously ignored silently when
        # use_focal=True; per-sample focal losses are now weighted too.
        if self.use_focal:
            per_sample = self.risk_loss(logits, risk_target)  # (B,)
            if domain_weight is not None:
                per_sample = per_sample * domain_weight
            L_risk = per_sample.mean()
        elif domain_weight is not None:
            # Per-sample weighted BCE: compute element-wise then weight
            bce_per_sample = F.binary_cross_entropy_with_logits(
                logits, risk_target,
                pos_weight=self.risk_loss.pos_weight.to(risk_logit.device),
                reduction="none",
            )
            L_risk = (bce_per_sample * domain_weight).mean()
        else:
            L_risk = self.risk_loss(logits, risk_target)

        # --- Miss distance regression (optionally domain-weighted MSE) ---
        miss_residual = (miss_pred_log.squeeze(-1) - miss_target_log) ** 2
        if domain_weight is not None:
            L_miss = (miss_residual * domain_weight).mean()
        else:
            L_miss = miss_residual.mean()

        # --- Collision probability regression in log10 space ---
        L_pc = torch.tensor(0.0, device=risk_logit.device)
        if pc_pred_log10 is not None and pc_target_log10 is not None:
            pc_residual = (pc_pred_log10.squeeze(-1) - pc_target_log10) ** 2
            if domain_weight is not None:
                L_pc = (pc_residual * domain_weight).mean()
            else:
                L_pc = pc_residual.mean()

        # --- Physics constraint: predicted miss >= MOID ---
        # Hinge penalty on how far below MOID the prediction falls.
        L_physics = torch.tensor(0.0, device=risk_logit.device)
        if moid_log is not None:
            violation = F.relu(moid_log - miss_pred_log.squeeze(-1))
            L_physics = violation.mean()

        total = (self.risk_weight * L_risk
                 + self.miss_weight * L_miss
                 + self.pc_weight * L_pc
                 + self.physics_weight * L_physics)

        metrics = {
            "loss": total.item(),
            "risk_loss": L_risk.item(),
            "miss_loss": L_miss.item(),
            "pc_loss": L_pc.item(),
            "physics_loss": L_physics.item(),
        }

        return total, metrics
src/model/pretrain.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-10
2
+ """Self-supervised pre-training for the PI-TFT encoder.
3
+
4
+ Masked Feature Reconstruction: mask 60% of CDM temporal features at random
5
+ per timestep, train the Transformer encoder to reconstruct them. This forces
6
+ the model to learn feature correlations, temporal dynamics, and
7
+ static-temporal interactions from ALL CDM data (no labels needed).
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from src.model.deep import PhysicsInformedTFT
14
+
15
+
16
class CDMMaskingStrategy(nn.Module):
    """Random per-(timestep, feature) masking for reconstruction pre-training.

    Real timesteps (per the padding mask) have a random subset of their
    temporal features swapped for a learnable [MASK] token; padded steps
    are never masked.
    """

    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # One learnable [MASK] value per temporal feature.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(
        self,
        temporal: torch.Tensor,      # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Mask a random subset of features on real timesteps.

        Returns:
            masked_temporal: (B, S, F_t) input with masked slots set to mask_token
            feature_mask:    (B, S, F_t) bool — True where a feature was masked
        """
        # Bernoulli(mask_ratio) draw per (batch, step, feature), then
        # restricted to real (non-padded) timesteps.
        draw = torch.rand(temporal.shape, device=temporal.device) < self.mask_ratio
        feature_mask = draw & padding_mask.unsqueeze(-1)

        # Swap masked slots for the learnable token (broadcast over B and S).
        masked_temporal = torch.where(
            feature_mask, self.mask_token.expand_as(temporal), temporal
        )
        return masked_temporal, feature_mask
55
+
56
+
57
class MaskedReconstructionHead(nn.Module):
    """Small MLP that maps encoder states back to raw temporal features.

    Kept deliberately shallow so that the representational work happens in
    the encoder rather than in this decoder.
    """

    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Map per-timestep hidden states (B, S, D) to feature space (B, S, F_t)."""
        return self.decoder(hidden)
84
+
85
+
86
class PretrainingWrapper(nn.Module):
    """Couples the PI-TFT encoder with masking and a reconstruction decoder.

    Pipeline per forward pass: draw a random feature mask, substitute the
    learnable [MASK] token, run the encoder over the corrupted sequence,
    then decode the hidden states back into feature space.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )

    def forward(
        self,
        temporal: torch.Tensor,      # (B, S, F_t)
        static: torch.Tensor,        # (B, F_s)
        time_to_tca: torch.Tensor,   # (B, S, 1)
        mask: torch.Tensor,          # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one masked-reconstruction pass.

        Returns:
            reconstructed: (B, S, F_t) decoder output in feature space
            feature_mask:  (B, S, F_t) bool — True where inputs were masked
            original:      (B, S, F_t) untouched inputs, for loss computation
        """
        # Keep a pristine copy as the reconstruction target.
        original = temporal.clone()

        corrupted, feature_mask = self.masking(temporal, mask)
        hidden, _ = self.encoder.encode_sequence(
            corrupted, static, time_to_tca, mask
        )
        reconstructed = self.reconstruction_head(hidden)

        return reconstructed, feature_mask, original
144
+
145
+
146
class PretrainingLoss(nn.Module):
    """Mean-squared reconstruction error restricted to masked positions."""

    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,       # (B, S, F_t)
        feature_mask: torch.Tensor,   # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        """Return (loss, metrics); the loss is zero when nothing was masked."""
        # Squared errors at masked positions only.
        errors = ((reconstructed - original) ** 2)[feature_mask]

        if errors.numel() == 0:
            # No masked positions: emit a zero that still participates
            # in autograd so backward() never fails.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)
        else:
            loss = errors.mean()

        return loss, {"reconstruction_loss": loss.item()}
src/model/triage.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by Claude Code -- 2026-02-13
2
+ """Urgency tier classifier for conjunction events."""
3
+
4
+ from enum import Enum
5
+ from dataclasses import dataclass
6
+
7
+
8
class UrgencyTier(str, Enum):
    """Closed set of conjunction urgency levels."""

    LOW = "LOW"
    MODERATE = "MODERATE"
    HIGH = "HIGH"


@dataclass
class TriageResult:
    """Outcome of urgency triage for a single conjunction event."""

    tier: UrgencyTier        # assigned urgency level
    color: str               # hex display color for dashboards
    recommendation: str      # operator-facing guidance text
    risk_probability: float  # model risk probability that produced the tier


# (upper threshold, tier, color, recommendation), checked in ascending order.
_TIER_TABLE = (
    (0.10, UrgencyTier.LOW, "#4fff8a",
     "Monitor conjunction. No action required."),
    (0.40, UrgencyTier.MODERATE, "#ffb84f",
     "Assess maneuver options. Increased monitoring recommended."),
)
# Fallback when risk exceeds every threshold above.
_HIGH_TIER = (UrgencyTier.HIGH, "#ff4f5a",
              "Immediate action required. Initiate collision avoidance maneuver.")


def classify_urgency(risk_prob: float) -> TriageResult:
    """Map a predicted risk probability to an urgency tier.

    Tiers:
        LOW      (risk <= 0.10): monitor only
        MODERATE (0.10 < risk <= 0.40): assess maneuver options
        HIGH     (risk > 0.40): immediate action required
    """
    for threshold, tier, color, recommendation in _TIER_TABLE:
        if risk_prob <= threshold:
            return TriageResult(tier, color, recommendation, risk_prob)
    tier, color, recommendation = _HIGH_TIER
    return TriageResult(tier, color, recommendation, risk_prob)
src/utils/__init__.py ADDED
File without changes