Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +30 -0
- README.md +26 -5
- app/__init__.py +1 -0
- app/main.py +481 -0
- app_wrapper.py +80 -0
- requirements.txt +7 -0
- results/deep_model_results.json +670 -0
- results/model_comparison.json +40 -0
- results/staleness_experiment.json +355 -0
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/augment.py +297 -0
- src/data/cdm_loader.py +205 -0
- src/data/counterfactual.py +458 -0
- src/data/density_features.py +259 -0
- src/data/firebase_client.py +159 -0
- src/data/maneuver_classifier.py +143 -0
- src/data/maneuver_detector.py +205 -0
- src/data/merge_sources.py +270 -0
- src/data/sequence_builder.py +497 -0
- src/data/spacetrack_crossref.py +185 -0
- src/evaluation/__init__.py +0 -0
- src/evaluation/conformal.py +307 -0
- src/evaluation/metrics.py +128 -0
- src/evaluation/staleness.py +263 -0
- src/model/__init__.py +0 -0
- src/model/baseline.py +107 -0
- src/model/classical.py +115 -0
- src/model/deep.py +448 -0
- src/model/pretrain.py +164 -0
- src/model/triage.py +50 -0
- src/utils/__init__.py +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
# (gcc: presumably needed to build any sdist-only Python wheels — TODO confirm still required)
RUN apt-get update && apt-get install -y --no-install-recommends gcc && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install CPU-only PyTorch first (smaller)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu

# Install remaining dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app/ ./app/
COPY src/ ./src/
COPY results/ ./results/
COPY app_wrapper.py .

# Create models directory (will be populated at runtime)
RUN mkdir -p models

# HF Spaces expects port 7860
EXPOSE 7860

# Run the wrapper that downloads models then starts uvicorn
CMD ["python", "app_wrapper.py"]
|
README.md
CHANGED
|
@@ -1,10 +1,31 @@
|
|
| 1 |
---
|
| 2 |
-
title: Panacea
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
colorTo: gray
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Panacea Satellite Collision Avoidance API
|
| 3 |
+
colorFrom: indigo
|
| 4 |
+
colorTo: blue
|
|
|
|
| 5 |
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Panacea -- Satellite Collision Avoidance API
|
| 12 |
+
|
| 13 |
+
FastAPI backend for the Panacea satellite collision avoidance system.
|
| 14 |
+
|
| 15 |
+
## Endpoints
|
| 16 |
+
|
| 17 |
+
- `GET /api/health` -- Health check, lists loaded models
|
| 18 |
+
- `POST /api/predict-conjunction` -- Run inference on a CDM sequence
|
| 19 |
+
- `GET /api/model-comparison` -- Pre-computed model comparison results
|
| 20 |
+
- `GET /api/experiment-results` -- Staleness experiment results
|
| 21 |
+
- `POST /api/bulk-screen` -- Screen TLE pairs for potential conjunctions
|
| 22 |
+
|
| 23 |
+
## Models
|
| 24 |
+
|
| 25 |
+
Three models are loaded at startup from [DTanzillo/panacea-models](https://huggingface.co/DTanzillo/panacea-models):
|
| 26 |
+
|
| 27 |
+
1. **Baseline** -- Orbital shell density prior (AUC-PR: 0.061)
|
| 28 |
+
2. **XGBoost** -- Classical ML on engineered CDM features (AUC-PR: 0.988)
|
| 29 |
+
3. **PI-TFT** -- Physics-Informed Temporal Fusion Transformer (AUC-PR: 0.511)
|
| 30 |
+
|
| 31 |
+
Built for AIPI 540 (Duke University).
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
app/main.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
| 2 |
+
"""FastAPI backend for Panacea collision avoidance inference."""
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from fastapi import FastAPI
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
ROOT = Path(__file__).parent.parent
|
| 19 |
+
sys.path.insert(0, str(ROOT))
|
| 20 |
+
|
| 21 |
+
from src.model.baseline import OrbitalShellBaseline
|
| 22 |
+
from src.model.classical import XGBoostConjunctionModel
|
| 23 |
+
from src.model.deep import PhysicsInformedTFT
|
| 24 |
+
from src.model.triage import classify_urgency
|
| 25 |
+
from src.data.sequence_builder import TEMPORAL_FEATURES, STATIC_FEATURES, MAX_SEQ_LEN
|
| 26 |
+
|
| 27 |
+
HF_REPO_ID = "DTanzillo/panacea-models"
|
| 28 |
+
|
| 29 |
+
# Global model storage
|
| 30 |
+
models = {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def download_models_from_hf(model_dir: Path, results_dir: Path):
    """Download model and result files from the HuggingFace Hub.

    Fetches the ``models/`` and ``results/`` subtrees of ``HF_REPO_ID`` and
    copies any file that is not already present into *model_dir* /
    *results_dir*.  Existing local files are never overwritten.  All
    failures (no network, missing repo, missing package) are logged and
    swallowed so the API can still start with whatever is on disk.

    Args:
        model_dir: Local directory that should receive model files.
        results_dir: Local directory that should receive result files.
    """
    import shutil  # hoisted out of the copy loops (was re-imported per file)

    try:
        from huggingface_hub import snapshot_download

        token = os.environ.get("HF_TOKEN")  # optional; repo may be public
        local = Path(snapshot_download(
            HF_REPO_ID,
            token=token,
            allow_patterns=["models/*", "results/*"],
        ))
        # Copy files to expected locations; guard each subtree so a missing
        # models/ dir no longer aborts the results/ copy.
        for subdir, dst_dir in (("models", model_dir), ("results", results_dir)):
            src_dir = local / subdir
            if not src_dir.exists():
                continue  # allow_patterns matched nothing for this subtree
            for src in src_dir.iterdir():
                dst = dst_dir / src.name
                if not dst.exists():
                    shutil.copy2(src, dst)
                    print(f" Downloaded {src.name} from HF Hub")
    except Exception as e:
        # Best-effort: the caller falls back to local files only.
        print(f" HF Hub download skipped: {e}")
|
| 60 |
+
|
| 61 |
+
def load_models():
    """Load all 3 models at startup. Downloads from HF Hub if missing.

    Populates the module-level ``models`` dict with keys "baseline",
    "xgboost", "pitft" (plus "pitft_checkpoint" / "pitft_device" bookkeeping
    entries).  Each model is loaded only if its artifact file exists, so a
    partial deployment still serves whatever models are available.
    """
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Try downloading from HF Hub if local models are missing
    # (baseline.json is used as the sentinel for "any models present").
    if not (model_dir / "baseline.json").exists():
        print(" Local models not found, trying HuggingFace Hub...")
        download_models_from_hf(model_dir, results_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    baseline_path = model_dir / "baseline.json"
    if baseline_path.exists():
        models["baseline"] = OrbitalShellBaseline.load(baseline_path)
        print(" Loaded baseline model")

    xgboost_path = model_dir / "xgboost.pkl"
    if xgboost_path.exists():
        models["xgboost"] = XGBoostConjunctionModel.load(xgboost_path)
        print(" Loaded XGBoost model")

    pitft_path = model_dir / "transformer.pt"
    if pitft_path.exists():
        # weights_only=False: checkpoint stores config/normalization dicts,
        # not just tensors. NOTE(review): this unpickles arbitrary objects —
        # acceptable only because the checkpoint comes from our own repo.
        checkpoint = torch.load(pitft_path, map_location=device, weights_only=False)
        config = checkpoint["config"]

        # Rebuild the architecture from the checkpoint's config; .get()
        # defaults cover older checkpoints missing the newer keys.
        model = PhysicsInformedTFT(
            n_temporal_features=config["n_temporal"],
            n_static_features=config["n_static"],
            d_model=config.get("d_model", 128),
            n_heads=config.get("n_heads", 4),
            n_layers=config.get("n_layers", 2),
        ).to(device)
        # strict=False for backward compat: old checkpoints lack pc_head weights
        model.load_state_dict(checkpoint["model_state"], strict=False)
        model.eval()

        # Keep the checkpoint around: inference needs its normalization
        # statistics and temperature (see _run_pitft_inference).
        models["pitft"] = model
        models["pitft_checkpoint"] = checkpoint
        models["pitft_device"] = device
        temp = checkpoint.get("temperature", 1.0)
        has_pc = checkpoint.get("has_pc_head", False)
        print(f" Loaded PI-TFT (epoch {checkpoint['epoch']}, T={temp:.3f}, pc_head={'yes' if has_pc else 'no'})")
|
| 107 |
+
|
| 108 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models before serving, drop them on shutdown."""
    print("Loading models ...")
    load_models()
    # Report only the real model entries, not the pitft_* bookkeeping keys.
    model_names = []
    for key in models:
        if not key.startswith("pitft_"):
            model_names.append(key)
    print(f"Models loaded: {model_names}")
    yield
    # Shutdown: release all model references.
    models.clear()
|
| 117 |
+
|
| 118 |
+
# FastAPI application; model loading/teardown is managed by the lifespan hook.
app = FastAPI(
    title="Panacea — Satellite Collision Avoidance API",
    version="1.0.0",
    lifespan=lifespan,
)

# Open CORS so a separately hosted frontend can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm credentials are needed,
# otherwise drop allow_credentials or pin the origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# --- Pydantic models ---
|
| 134 |
+
|
| 135 |
+
class CDMFeatures(BaseModel):
    """A sequence of CDM feature snapshots for one conjunction event."""
    # Optional caller-supplied identifier; not read by the prediction endpoint.
    event_id: Optional[int] = None
    # Chronological CDM snapshots; each dict maps feature name -> value.
    cdm_sequence: list[dict]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class BulkScreenRequest(BaseModel):
    """TLE data for pairwise screening."""
    # GP/TLE records as dicts (keys like OBJECT_NAME, MEAN_MOTION, ECCENTRICITY).
    tles: list[dict]
    # Maximum number of highest-risk pairs to return.
    top_k: int = 10
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# --- Endpoints ---
|
| 148 |
+
|
| 149 |
+
@app.get("/api/health")
|
| 150 |
+
async def health():
|
| 151 |
+
loaded = []
|
| 152 |
+
if "baseline" in models:
|
| 153 |
+
loaded.append("baseline")
|
| 154 |
+
if "xgboost" in models:
|
| 155 |
+
loaded.append("xgboost")
|
| 156 |
+
if "pitft" in models:
|
| 157 |
+
loaded.append("pitft")
|
| 158 |
+
|
| 159 |
+
device = str(models.get("pitft_device", "cpu"))
|
| 160 |
+
return {
|
| 161 |
+
"status": "healthy",
|
| 162 |
+
"models_loaded": loaded,
|
| 163 |
+
"device": device,
|
| 164 |
+
"n_models": len(loaded),
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@app.post("/api/predict-conjunction")
|
| 169 |
+
async def predict_conjunction(features: CDMFeatures):
|
| 170 |
+
"""Run inference on a single conjunction event across all loaded models."""
|
| 171 |
+
results = {}
|
| 172 |
+
cdm_seq = features.cdm_sequence
|
| 173 |
+
if not cdm_seq:
|
| 174 |
+
return {"error": "Empty CDM sequence"}
|
| 175 |
+
|
| 176 |
+
last_cdm = cdm_seq[-1]
|
| 177 |
+
altitude = last_cdm.get("t_h_apo", last_cdm.get("c_h_apo", 500.0))
|
| 178 |
+
|
| 179 |
+
# Baseline prediction
|
| 180 |
+
if "baseline" in models:
|
| 181 |
+
risk_probs, miss_preds = models["baseline"].predict(np.array([altitude]))
|
| 182 |
+
triage = classify_urgency(float(risk_probs[0]))
|
| 183 |
+
results["baseline"] = {
|
| 184 |
+
"risk_probability": float(risk_probs[0]),
|
| 185 |
+
"miss_distance_km": float(np.expm1(miss_preds[0])),
|
| 186 |
+
"triage": {
|
| 187 |
+
"tier": triage.tier.value,
|
| 188 |
+
"color": triage.color,
|
| 189 |
+
"recommendation": triage.recommendation,
|
| 190 |
+
},
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
# XGBoost prediction
|
| 194 |
+
if "xgboost" in models:
|
| 195 |
+
xgb_features = _build_xgboost_features(cdm_seq)
|
| 196 |
+
risk_probs, miss_km = models["xgboost"].predict(xgb_features)
|
| 197 |
+
triage = classify_urgency(float(risk_probs[0]))
|
| 198 |
+
results["xgboost"] = {
|
| 199 |
+
"risk_probability": float(risk_probs[0]),
|
| 200 |
+
"miss_distance_km": float(miss_km[0]),
|
| 201 |
+
"triage": {
|
| 202 |
+
"tier": triage.tier.value,
|
| 203 |
+
"color": triage.color,
|
| 204 |
+
"recommendation": triage.recommendation,
|
| 205 |
+
},
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# PI-TFT prediction
|
| 209 |
+
if "pitft" in models:
|
| 210 |
+
risk_prob, miss_log, pc_log10 = _run_pitft_inference(cdm_seq)
|
| 211 |
+
triage = classify_urgency(risk_prob)
|
| 212 |
+
results["pitft"] = {
|
| 213 |
+
"risk_probability": risk_prob,
|
| 214 |
+
"miss_distance_km": float(np.expm1(miss_log)),
|
| 215 |
+
"collision_probability": float(10 ** pc_log10),
|
| 216 |
+
"collision_probability_log10": pc_log10,
|
| 217 |
+
"triage": {
|
| 218 |
+
"tier": triage.tier.value,
|
| 219 |
+
"color": triage.color,
|
| 220 |
+
"recommendation": triage.recommendation,
|
| 221 |
+
},
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
return results
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@app.get("/api/model-comparison")
|
| 228 |
+
async def model_comparison():
|
| 229 |
+
"""Return pre-computed model comparison results."""
|
| 230 |
+
results = []
|
| 231 |
+
|
| 232 |
+
comparison_path = ROOT / "results" / "model_comparison.json"
|
| 233 |
+
if comparison_path.exists():
|
| 234 |
+
with open(comparison_path) as f:
|
| 235 |
+
results = json.load(f)
|
| 236 |
+
|
| 237 |
+
deep_path = ROOT / "results" / "deep_model_results.json"
|
| 238 |
+
if deep_path.exists():
|
| 239 |
+
with open(deep_path) as f:
|
| 240 |
+
deep = json.load(f)
|
| 241 |
+
pitft_entry = {
|
| 242 |
+
"model": deep["model"],
|
| 243 |
+
**deep["test"],
|
| 244 |
+
}
|
| 245 |
+
results.append(pitft_entry)
|
| 246 |
+
|
| 247 |
+
return results
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@app.get("/api/experiment-results")
|
| 251 |
+
async def experiment_results():
|
| 252 |
+
"""Return staleness experiment results."""
|
| 253 |
+
exp_path = ROOT / "results" / "staleness_experiment.json"
|
| 254 |
+
if exp_path.exists():
|
| 255 |
+
with open(exp_path) as f:
|
| 256 |
+
return json.load(f)
|
| 257 |
+
return {"error": "No experiment results found. Run: python scripts/run_experiment.py"}
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@app.post("/api/bulk-screen")
|
| 261 |
+
async def bulk_screen(request: BulkScreenRequest):
|
| 262 |
+
"""Screen TLE pairs for potential conjunctions using orbital filtering."""
|
| 263 |
+
tles = request.tles
|
| 264 |
+
top_k = request.top_k
|
| 265 |
+
|
| 266 |
+
if len(tles) < 2:
|
| 267 |
+
return {"pairs": [], "n_candidates": 0, "n_total": len(tles)}
|
| 268 |
+
|
| 269 |
+
n = len(tles)
|
| 270 |
+
names = [t.get("OBJECT_NAME", f"Object {i}") for i, t in enumerate(tles)]
|
| 271 |
+
norad_ids = [t.get("NORAD_CAT_ID", 0) for t in tles]
|
| 272 |
+
|
| 273 |
+
# Compute altitude from mean motion: a = (mu / n^2)^(1/3), alt = a - R_earth
|
| 274 |
+
MU = 398600.4418 # km^3/s^2
|
| 275 |
+
R_EARTH = 6371.0 # km
|
| 276 |
+
|
| 277 |
+
mean_motions = np.array([t.get("MEAN_MOTION", 15.0) for t in tles])
|
| 278 |
+
n_rad = mean_motions * 2 * np.pi / 86400.0
|
| 279 |
+
n_rad = np.clip(n_rad, 1e-10, None)
|
| 280 |
+
sma = (MU / (n_rad ** 2)) ** (1.0 / 3.0)
|
| 281 |
+
|
| 282 |
+
eccentricities = np.array([t.get("ECCENTRICITY", 0.0) for t in tles])
|
| 283 |
+
apogee = sma * (1 + eccentricities) - R_EARTH
|
| 284 |
+
perigee = sma * (1 - eccentricities) - R_EARTH
|
| 285 |
+
|
| 286 |
+
raan = np.array([t.get("RA_OF_ASC_NODE", 0.0) for t in tles])
|
| 287 |
+
|
| 288 |
+
# Pairwise filtering via broadcasting
|
| 289 |
+
alt_overlap = ((apogee[:, None] >= perigee[None, :]) &
|
| 290 |
+
(apogee[None, :] >= perigee[:, None]))
|
| 291 |
+
|
| 292 |
+
raan_diff = np.abs(raan[:, None] - raan[None, :])
|
| 293 |
+
raan_diff = np.minimum(raan_diff, 360.0 - raan_diff)
|
| 294 |
+
raan_close = raan_diff < 30.0
|
| 295 |
+
|
| 296 |
+
candidates = alt_overlap & raan_close
|
| 297 |
+
np.fill_diagonal(candidates, False)
|
| 298 |
+
candidates = np.triu(candidates, k=1)
|
| 299 |
+
|
| 300 |
+
pairs_i, pairs_j = np.where(candidates)
|
| 301 |
+
|
| 302 |
+
if len(pairs_i) == 0:
|
| 303 |
+
return {"pairs": [], "n_candidates": 0, "n_total": n}
|
| 304 |
+
|
| 305 |
+
# Score candidates using baseline model
|
| 306 |
+
if "baseline" in models:
|
| 307 |
+
pair_altitudes = (apogee[pairs_i] + apogee[pairs_j]) / 2.0
|
| 308 |
+
risk_scores, miss_estimates = models["baseline"].predict(pair_altitudes)
|
| 309 |
+
else:
|
| 310 |
+
risk_scores = np.ones(len(pairs_i)) * 0.5
|
| 311 |
+
miss_estimates = np.zeros(len(pairs_i))
|
| 312 |
+
|
| 313 |
+
top_indices = np.argsort(-risk_scores)[:top_k]
|
| 314 |
+
|
| 315 |
+
result_pairs = []
|
| 316 |
+
for idx in top_indices:
|
| 317 |
+
i, j = int(pairs_i[idx]), int(pairs_j[idx])
|
| 318 |
+
result_pairs.append({
|
| 319 |
+
"name_1": names[i],
|
| 320 |
+
"name_2": names[j],
|
| 321 |
+
"norad_1": norad_ids[i],
|
| 322 |
+
"norad_2": norad_ids[j],
|
| 323 |
+
"risk_score": float(risk_scores[idx]),
|
| 324 |
+
"altitude_km": float((apogee[i] + apogee[j]) / 2),
|
| 325 |
+
"miss_estimate_km": (float(np.expm1(miss_estimates[idx]))
|
| 326 |
+
if miss_estimates[idx] > 0 else 0.0),
|
| 327 |
+
})
|
| 328 |
+
|
| 329 |
+
return {
|
| 330 |
+
"pairs": result_pairs,
|
| 331 |
+
"n_candidates": int(len(pairs_i)),
|
| 332 |
+
"n_total": n,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# --- Helper functions ---
|
| 337 |
+
|
| 338 |
+
def _build_xgboost_features(cdm_sequence: list[dict]) -> np.ndarray:
    """Build XGBoost feature vector from a CDM sequence (dict format).

    Replicates events_to_flat_features() logic for a single event: the
    sorted numeric fields of the latest CDM plus 8 temporal summary
    statistics, padded or truncated to the fitted scaler's width.

    Args:
        cdm_sequence: Chronological CDM snapshots (dicts of numeric fields).

    Returns:
        Array of shape (1, n_features) with NaN/inf replaced by 0.
    """
    last = cdm_sequence[-1]

    # Base features: all numeric fields of the latest CDM, sorted by key so
    # the column order matches training-time construction.
    exclude = {"event_id", "time_to_tca", "risk", "mission_id"}
    feature_keys = sorted(
        k for k in last
        if isinstance(last.get(k), (int, float)) and k not in exclude
    )
    base = np.array([float(last.get(k, 0.0)) for k in feature_keys], dtype=np.float32)

    miss_values = np.array([float(s.get("miss_distance", 0.0)) for s in cdm_sequence])
    risk_values = np.array([float(s.get("risk", -10.0)) for s in cdm_sequence])
    tca_values = np.array([float(s.get("time_to_tca", 0.0)) for s in cdm_sequence])

    n_cdms = len(cdm_sequence)
    miss_mean = float(np.mean(miss_values))
    miss_std = float(np.std(miss_values)) if n_cdms > 1 else 0.0

    # Linear trends vs. time-to-TCA: only well-defined with >1 CDM and a
    # non-constant time axis.  (The original evaluated this guard — and
    # np.std(tca_values) — twice; hoisted into a single check.)
    miss_trend = 0.0
    risk_trend = 0.0
    if n_cdms > 1 and np.std(tca_values) > 0:
        miss_trend = float(np.polyfit(tca_values, miss_values, 1)[0])
        risk_trend = float(np.polyfit(tca_values, risk_values, 1)[0])

    temporal_feats = np.array([
        n_cdms,
        miss_mean,
        miss_std,
        miss_trend,
        risk_trend,
        # Total miss-distance change over the observation window.
        float(miss_values[0] - miss_values[-1]) if n_cdms > 1 else 0.0,
        float(last.get("time_to_tca", 0.0)),
        float(last.get("relative_speed", 0.0)),
    ], dtype=np.float32)

    combined = np.concatenate([base, temporal_feats])
    combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
    X = combined.reshape(1, -1)

    # Pad/truncate to the scaler's fitted width: the deployed model may have
    # been trained on augmented data with additional columns.
    if "xgboost" in models:
        expected = models["xgboost"].scaler.n_features_in_
        if X.shape[1] < expected:
            padding = np.zeros((X.shape[0], expected - X.shape[1]), dtype=X.dtype)
            X = np.hstack([X, padding])
        elif X.shape[1] > expected:
            X = X[:, :expected]

    return X
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def _run_pitft_inference(cdm_sequence: list[dict]) -> tuple[float, float, float]:
    """Run PI-TFT inference on a single CDM sequence.

    Normalizes temporal/static features with statistics stored in the
    training checkpoint, pads the sequence to MAX_SEQ_LEN, and applies
    temperature scaling to the risk logit.

    Returns: (risk_probability, miss_log, pc_log10)
    """
    checkpoint = models["pitft_checkpoint"]
    device = models["pitft_device"]
    model = models["pitft"]
    norm = checkpoint["normalization"]
    temperature = checkpoint.get("temperature", 1.0)
    # Older checkpoints did not store their feature lists; fall back to the
    # current module-level definitions in that case.
    temporal_cols = checkpoint.get("temporal_cols", TEMPORAL_FEATURES)
    static_cols = checkpoint.get("static_cols", STATIC_FEATURES)

    # Extract temporal features: (S, F_t)
    temporal = np.array([
        [float(cdm.get(col, 0.0)) for col in temporal_cols]
        for cdm in cdm_sequence
    ], dtype=np.float32)
    temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

    # Compute deltas between consecutive CDMs; first row is zero (no
    # predecessor), keeping the delta array the same length as temporal.
    if len(temporal) > 1:
        deltas = np.diff(temporal, axis=0)
        deltas = np.concatenate(
            [np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0
        )
    else:
        deltas = np.zeros_like(temporal)

    # Normalize with the training-set statistics stored in the checkpoint.
    t_mean = np.array(norm["temporal_mean"], dtype=np.float32)
    t_std = np.array(norm["temporal_std"], dtype=np.float32)
    d_mean = np.array(norm["delta_mean"], dtype=np.float32)
    d_std = np.array(norm["delta_std"], dtype=np.float32)
    s_mean = np.array(norm["static_mean"], dtype=np.float32)
    s_std = np.array(norm["static_std"], dtype=np.float32)

    temporal = (temporal - t_mean) / t_std
    deltas = (deltas - d_mean) / d_std
    # Model input concatenates normalized levels and deltas along features.
    temporal = np.concatenate([temporal, deltas], axis=1)

    # Static features from last CDM
    last_cdm = cdm_sequence[-1]
    static = np.array(
        [float(last_cdm.get(col, 0.0)) for col in static_cols], dtype=np.float32
    )
    static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
    static = (static - s_mean) / s_std

    # Time-to-TCA channel, normalized with its own stored statistics.
    tca_mean = norm["tca_mean"]
    tca_std = norm["tca_std"]
    tca = np.array(
        [float(cdm.get("time_to_tca", 0.0)) for cdm in cdm_sequence], dtype=np.float32
    ).reshape(-1, 1)
    tca = (tca - tca_mean) / tca_std

    # Pad/truncate to MAX_SEQ_LEN (keep the most recent CDMs when truncating;
    # left-pad with zeros when short).
    seq_len = len(temporal)
    if seq_len > MAX_SEQ_LEN:
        temporal = temporal[-MAX_SEQ_LEN:]
        tca = tca[-MAX_SEQ_LEN:]
        seq_len = MAX_SEQ_LEN

    pad_len = MAX_SEQ_LEN - seq_len
    if pad_len > 0:
        temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
        tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

    # Attention mask: True on real timesteps, False on left padding.
    mask = np.zeros(MAX_SEQ_LEN, dtype=bool)
    mask[pad_len:] = True

    # Convert to tensors with a leading batch dimension of 1.
    temporal_t = torch.tensor(temporal, dtype=torch.float32).unsqueeze(0).to(device)
    static_t = torch.tensor(static, dtype=torch.float32).unsqueeze(0).to(device)
    tca_t = torch.tensor(tca, dtype=torch.float32).unsqueeze(0).to(device)
    mask_t = torch.tensor(mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        risk_logit, miss_log, pc_log10, _ = model(temporal_t, static_t, tca_t, mask_t)

    # Temperature scaling calibrates the risk probability.
    risk_prob = float(torch.sigmoid(risk_logit / temperature).cpu().item())
    miss_log_val = float(miss_log.cpu().item())
    pc_log10_val = float(pc_log10.cpu().item())

    return risk_prob, miss_log_val, pc_log10_val
|
app_wrapper.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Startup wrapper for HuggingFace Spaces deployment.
|
| 2 |
+
|
| 3 |
+
Downloads models from DTanzillo/panacea-models on first run,
|
| 4 |
+
then starts the FastAPI application on port 7860.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import shutil
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Ensure the app root is on the Python path
|
| 13 |
+
ROOT = Path(__file__).parent
|
| 14 |
+
sys.path.insert(0, str(ROOT))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def download_models():
    """Download models from HuggingFace Hub if not present locally."""
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Skip the network entirely when every expected artifact is on disk.
    needed_files = ("baseline.json", "xgboost.pkl", "transformer.pt")
    if all((model_dir / name).exists() for name in needed_files):
        print("Models already present, skipping download.")
        return

    print("Downloading models from DTanzillo/panacea-models ...")
    try:
        from huggingface_hub import snapshot_download

        snapshot_root = Path(snapshot_download(
            "DTanzillo/panacea-models",
            token=os.environ.get("HF_TOKEN"),
            allow_patterns=["models/*", "results/*"],
        ))

        def _copy_missing(src_dir, dst_dir, label):
            """Copy snapshot files into dst_dir without overwriting local ones."""
            if not src_dir.exists():
                return
            for src_file in src_dir.iterdir():
                dst_file = dst_dir / src_file.name
                if not dst_file.exists():
                    shutil.copy2(src_file, dst_file)
                    print(f" Copied {label}{src_file.name}")

        _copy_missing(snapshot_root / "models", model_dir, "")
        _copy_missing(snapshot_root / "results", results_dir, "result: ")

        print("Model download complete.")
    except Exception as e:
        print(f"WARNING: Model download failed: {e}")
        print("The API will start but models may not be available.")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
# Step 1: Download models
|
| 69 |
+
download_models()
|
| 70 |
+
|
| 71 |
+
# Step 2: Start uvicorn
|
| 72 |
+
import uvicorn
|
| 73 |
+
port = int(os.environ.get("PORT", 7860))
|
| 74 |
+
print(f"Starting Panacea API on port {port} ...")
|
| 75 |
+
uvicorn.run(
|
| 76 |
+
"app.main:app",
|
| 77 |
+
host="0.0.0.0",
|
| 78 |
+
port=port,
|
| 79 |
+
log_level="info",
|
| 80 |
+
)
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.6
|
| 2 |
+
uvicorn[standard]==0.34.0
|
| 3 |
+
xgboost==2.1.4
|
| 4 |
+
scikit-learn==1.6.1
|
| 5 |
+
pandas==2.2.3
|
| 6 |
+
numpy==2.2.2
|
| 7 |
+
huggingface-hub>=0.27.0
|
results/deep_model_results.json
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": "PI-TFT (Physics-Informed Temporal Fusion Transformer)",
|
| 3 |
+
"best_epoch": 36,
|
| 4 |
+
"training_time_minutes": 13.175001474221547,
|
| 5 |
+
"optimal_threshold": 0.2639383375644684,
|
| 6 |
+
"temperature": 0.6179193258285522,
|
| 7 |
+
"use_density": true,
|
| 8 |
+
"test": {
|
| 9 |
+
"loss": 0.021245601093944383,
|
| 10 |
+
"auc_pr": 0.5076785607710974,
|
| 11 |
+
"auc_roc": 0.946749355627952,
|
| 12 |
+
"f1_at_50": 0.0,
|
| 13 |
+
"n_positive": 73,
|
| 14 |
+
"n_total": 2167,
|
| 15 |
+
"pos_rate": 0.03368712589144707,
|
| 16 |
+
"f1": 0.5185185137773299,
|
| 17 |
+
"optimal_threshold": 0.2639383375644684,
|
| 18 |
+
"threshold": 0.2639383375644684,
|
| 19 |
+
"recall_at_prec_30": 0.7808219178082192,
|
| 20 |
+
"recall_at_prec_50": 0.4931506849315068,
|
| 21 |
+
"recall_at_prec_70": 0.2876712328767123,
|
| 22 |
+
"mae_log": 0.10174570232629776,
|
| 23 |
+
"rmse_log": 0.15394317551905587,
|
| 24 |
+
"mae_km": 1533.616943359375,
|
| 25 |
+
"median_abs_error_km": 926.875
|
| 26 |
+
},
|
| 27 |
+
"test_calibrated": {
|
| 28 |
+
"auc_pr": 0.5076785607710974,
|
| 29 |
+
"auc_roc": 0.946749355627952,
|
| 30 |
+
"f1_at_50": 0.0,
|
| 31 |
+
"n_positive": 73,
|
| 32 |
+
"n_total": 2167,
|
| 33 |
+
"pos_rate": 0.03368712589144707,
|
| 34 |
+
"f1": 0.5185185137773299,
|
| 35 |
+
"optimal_threshold": 0.15979407727718353,
|
| 36 |
+
"threshold": 0.15979407727718353,
|
| 37 |
+
"recall_at_prec_30": 0.7808219178082192,
|
| 38 |
+
"recall_at_prec_50": 0.4931506849315068,
|
| 39 |
+
"recall_at_prec_70": 0.2876712328767123
|
| 40 |
+
},
|
| 41 |
+
"history": [
|
| 42 |
+
{
|
| 43 |
+
"epoch": 1,
|
| 44 |
+
"train_loss": 6.801268232190931,
|
| 45 |
+
"val_loss": 5.25680010659354,
|
| 46 |
+
"val_auc_pr": 0.007896529454622946,
|
| 47 |
+
"val_f1": 0.019323671026161646,
|
| 48 |
+
"val_mae_log": 7.11151123046875
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"epoch": 2,
|
| 52 |
+
"train_loss": 3.834329532932591,
|
| 53 |
+
"val_loss": 2.7224787643977573,
|
| 54 |
+
"val_auc_pr": 0.010594418921027337,
|
| 55 |
+
"val_f1": 0.023529411193771638,
|
| 56 |
+
"val_mae_log": 5.041237831115723
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"epoch": 3,
|
| 60 |
+
"train_loss": 1.955074778118649,
|
| 61 |
+
"val_loss": 1.1283516032355172,
|
| 62 |
+
"val_auc_pr": 0.008480584727743306,
|
| 63 |
+
"val_f1": 0.021505376131344667,
|
| 64 |
+
"val_mae_log": 3.112034797668457
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"epoch": 4,
|
| 68 |
+
"train_loss": 0.6309667991625296,
|
| 69 |
+
"val_loss": 0.2000983421291624,
|
| 70 |
+
"val_auc_pr": 0.047413803659580166,
|
| 71 |
+
"val_f1": 0.11764705467128042,
|
| 72 |
+
"val_mae_log": 1.1961653232574463
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"epoch": 5,
|
| 76 |
+
"train_loss": 0.13499785540877163,
|
| 77 |
+
"val_loss": 0.02656353052173342,
|
| 78 |
+
"val_auc_pr": 0.05766442486817594,
|
| 79 |
+
"val_f1": 0.15999999680000007,
|
| 80 |
+
"val_mae_log": 0.29869771003723145
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 6,
|
| 84 |
+
"train_loss": 0.07689017317182309,
|
| 85 |
+
"val_loss": 0.02750414184161595,
|
| 86 |
+
"val_auc_pr": 0.134885373440643,
|
| 87 |
+
"val_f1": 0.27272726921487606,
|
| 88 |
+
"val_mae_log": 0.3075650930404663
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"epoch": 7,
|
| 92 |
+
"train_loss": 0.08175783813805193,
|
| 93 |
+
"val_loss": 0.07211375555821828,
|
| 94 |
+
"val_auc_pr": 0.18529914529914526,
|
| 95 |
+
"val_f1": 0.4285714239795918,
|
| 96 |
+
"val_mae_log": 0.6812126040458679
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"epoch": 8,
|
| 100 |
+
"train_loss": 0.07750273872468923,
|
| 101 |
+
"val_loss": 0.027415024914911816,
|
| 102 |
+
"val_auc_pr": 0.13237697916045849,
|
| 103 |
+
"val_f1": 0.3157894698060942,
|
| 104 |
+
"val_mae_log": 0.35104697942733765
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"epoch": 9,
|
| 108 |
+
"train_loss": 0.06653158048520218,
|
| 109 |
+
"val_loss": 0.01911477212394987,
|
| 110 |
+
"val_auc_pr": 0.20693184703085693,
|
| 111 |
+
"val_f1": 0.374999995703125,
|
| 112 |
+
"val_mae_log": 0.2960411608219147
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"epoch": 10,
|
| 116 |
+
"train_loss": 0.0626621154917253,
|
| 117 |
+
"val_loss": 0.020604882389307022,
|
| 118 |
+
"val_auc_pr": 0.3348872180451128,
|
| 119 |
+
"val_f1": 0.5454545404958678,
|
| 120 |
+
"val_mae_log": 0.23688556253910065
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"epoch": 11,
|
| 124 |
+
"train_loss": 0.0617836594581604,
|
| 125 |
+
"val_loss": 0.012763384197439467,
|
| 126 |
+
"val_auc_pr": 0.1294155844155844,
|
| 127 |
+
"val_f1": 0.22222221920438956,
|
| 128 |
+
"val_mae_log": 0.1817978173494339
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"epoch": 12,
|
| 132 |
+
"train_loss": 0.05554375463240856,
|
| 133 |
+
"val_loss": 0.01185049262962171,
|
| 134 |
+
"val_auc_pr": 0.24263038548752833,
|
| 135 |
+
"val_f1": 0.36363635867768596,
|
| 136 |
+
"val_mae_log": 0.15147316455841064
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 13,
|
| 140 |
+
"train_loss": 0.05319682077781574,
|
| 141 |
+
"val_loss": 0.017937806567975452,
|
| 142 |
+
"val_auc_pr": 0.2786109128966272,
|
| 143 |
+
"val_f1": 0.33333333055555564,
|
| 144 |
+
"val_mae_log": 0.21772687137126923
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"epoch": 14,
|
| 148 |
+
"train_loss": 0.05603743799634882,
|
| 149 |
+
"val_loss": 0.012255215285612004,
|
| 150 |
+
"val_auc_pr": 0.1654839208410637,
|
| 151 |
+
"val_f1": 0.3076923029585799,
|
| 152 |
+
"val_mae_log": 0.12889182567596436
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"epoch": 15,
|
| 156 |
+
"train_loss": 0.052231158416818926,
|
| 157 |
+
"val_loss": 0.008827194571495056,
|
| 158 |
+
"val_auc_pr": 0.30569487983281085,
|
| 159 |
+
"val_f1": 0.4705882311418686,
|
| 160 |
+
"val_mae_log": 0.11871597170829773
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"epoch": 16,
|
| 164 |
+
"train_loss": 0.050459702796227225,
|
| 165 |
+
"val_loss": 0.006688231070126806,
|
| 166 |
+
"val_auc_pr": 0.3174495864073329,
|
| 167 |
+
"val_f1": 0.33333333055555564,
|
| 168 |
+
"val_mae_log": 0.11670727282762527
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"epoch": 17,
|
| 172 |
+
"train_loss": 0.05048987201943591,
|
| 173 |
+
"val_loss": 0.012136828287371568,
|
| 174 |
+
"val_auc_pr": 0.209023569023569,
|
| 175 |
+
"val_f1": 0.3529411723183391,
|
| 176 |
+
"val_mae_log": 0.15395033359527588
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"epoch": 18,
|
| 180 |
+
"train_loss": 0.05087649694367035,
|
| 181 |
+
"val_loss": 0.007568871269800833,
|
| 182 |
+
"val_auc_pr": 0.2673856209150327,
|
| 183 |
+
"val_f1": 0.3999999962500001,
|
| 184 |
+
"val_mae_log": 0.1411171853542328
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"epoch": 19,
|
| 188 |
+
"train_loss": 0.050642090935159374,
|
| 189 |
+
"val_loss": 0.0066412134495164666,
|
| 190 |
+
"val_auc_pr": 0.27475908192734455,
|
| 191 |
+
"val_f1": 0.3999999955555556,
|
| 192 |
+
"val_mae_log": 0.0915408581495285
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 20,
|
| 196 |
+
"train_loss": 0.04991532632628003,
|
| 197 |
+
"val_loss": 0.0055730888686542,
|
| 198 |
+
"val_auc_pr": 0.24940384615384617,
|
| 199 |
+
"val_f1": 0.33333332932098775,
|
| 200 |
+
"val_mae_log": 0.10347151011228561
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"epoch": 21,
|
| 204 |
+
"train_loss": 0.049406778288854133,
|
| 205 |
+
"val_loss": 0.008397463309977735,
|
| 206 |
+
"val_auc_pr": 0.22877207681961503,
|
| 207 |
+
"val_f1": 0.2857142816326531,
|
| 208 |
+
"val_mae_log": 0.15620921552181244
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"epoch": 22,
|
| 212 |
+
"train_loss": 0.04929839575008766,
|
| 213 |
+
"val_loss": 0.0075396452365177015,
|
| 214 |
+
"val_auc_pr": 0.3359158185268243,
|
| 215 |
+
"val_f1": 0.33333333055555564,
|
| 216 |
+
"val_mae_log": 0.11639901250600815
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"epoch": 23,
|
| 220 |
+
"train_loss": 0.04896112705606061,
|
| 221 |
+
"val_loss": 0.007832049591732877,
|
| 222 |
+
"val_auc_pr": 0.3431446821152704,
|
| 223 |
+
"val_f1": 0.36363636012396694,
|
| 224 |
+
"val_mae_log": 0.10894307494163513
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"epoch": 24,
|
| 228 |
+
"train_loss": 0.048813931744646384,
|
| 229 |
+
"val_loss": 0.0061542981836412635,
|
| 230 |
+
"val_auc_pr": 0.3559577677224736,
|
| 231 |
+
"val_f1": 0.36363636012396694,
|
| 232 |
+
"val_mae_log": 0.07847719639539719
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"epoch": 25,
|
| 236 |
+
"train_loss": 0.04768835706888019,
|
| 237 |
+
"val_loss": 0.006223144009709358,
|
| 238 |
+
"val_auc_pr": 0.3659761291340239,
|
| 239 |
+
"val_f1": 0.421052627700831,
|
| 240 |
+
"val_mae_log": 0.14390207827091217
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"epoch": 26,
|
| 244 |
+
"train_loss": 0.04840076712740434,
|
| 245 |
+
"val_loss": 0.0067752449374113765,
|
| 246 |
+
"val_auc_pr": 0.2586657651566374,
|
| 247 |
+
"val_f1": 0.34782608355387534,
|
| 248 |
+
"val_mae_log": 0.1449323147535324
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"epoch": 27,
|
| 252 |
+
"train_loss": 0.047609428044509246,
|
| 253 |
+
"val_loss": 0.0065139371103474075,
|
| 254 |
+
"val_auc_pr": 0.34384112619406737,
|
| 255 |
+
"val_f1": 0.34782608355387534,
|
| 256 |
+
"val_mae_log": 0.09073375165462494
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"epoch": 28,
|
| 260 |
+
"train_loss": 0.04662630880201185,
|
| 261 |
+
"val_loss": 0.006256445976240295,
|
| 262 |
+
"val_auc_pr": 0.33832141293241863,
|
| 263 |
+
"val_f1": 0.33333333055555564,
|
| 264 |
+
"val_mae_log": 0.07596895098686218
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"epoch": 29,
|
| 268 |
+
"train_loss": 0.04634691820152708,
|
| 269 |
+
"val_loss": 0.005017333896830678,
|
| 270 |
+
"val_auc_pr": 0.336514012303486,
|
| 271 |
+
"val_f1": 0.33333333055555564,
|
| 272 |
+
"val_mae_log": 0.07677556574344635
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"epoch": 30,
|
| 276 |
+
"train_loss": 0.04663669626052315,
|
| 277 |
+
"val_loss": 0.004762223763723991,
|
| 278 |
+
"val_auc_pr": 0.24682988580047405,
|
| 279 |
+
"val_f1": 0.36363636012396694,
|
| 280 |
+
"val_mae_log": 0.08992886543273926
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"epoch": 31,
|
| 284 |
+
"train_loss": 0.046282403110652355,
|
| 285 |
+
"val_loss": 0.003826435888186097,
|
| 286 |
+
"val_auc_pr": 0.2284485407066052,
|
| 287 |
+
"val_f1": 0.3999999962500001,
|
| 288 |
+
"val_mae_log": 0.06141701713204384
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"epoch": 32,
|
| 292 |
+
"train_loss": 0.04575154318197353,
|
| 293 |
+
"val_loss": 0.005115043604746461,
|
| 294 |
+
"val_auc_pr": 0.3611255411255411,
|
| 295 |
+
"val_f1": 0.3999999962500001,
|
| 296 |
+
"val_mae_log": 0.09008380770683289
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"epoch": 33,
|
| 300 |
+
"train_loss": 0.046043931763317135,
|
| 301 |
+
"val_loss": 0.004483342935730304,
|
| 302 |
+
"val_auc_pr": 0.36333333333333334,
|
| 303 |
+
"val_f1": 0.3809523773242631,
|
| 304 |
+
"val_mae_log": 0.10232321172952652
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"epoch": 34,
|
| 308 |
+
"train_loss": 0.04492839058307377,
|
| 309 |
+
"val_loss": 0.007276699944798436,
|
| 310 |
+
"val_auc_pr": 0.3461904761904762,
|
| 311 |
+
"val_f1": 0.3809523773242631,
|
| 312 |
+
"val_mae_log": 0.10686437785625458
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"epoch": 35,
|
| 316 |
+
"train_loss": 0.04576677558188503,
|
| 317 |
+
"val_loss": 0.004259714224774923,
|
| 318 |
+
"val_auc_pr": 0.37718954248366016,
|
| 319 |
+
"val_f1": 0.3999999962500001,
|
| 320 |
+
"val_mae_log": 0.0769796371459961
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"epoch": 36,
|
| 324 |
+
"train_loss": 0.044130372638637956,
|
| 325 |
+
"val_loss": 0.004274079659288483,
|
| 326 |
+
"val_auc_pr": 0.4215151515151515,
|
| 327 |
+
"val_f1": 0.4444444395061729,
|
| 328 |
+
"val_mae_log": 0.09318451583385468
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"epoch": 37,
|
| 332 |
+
"train_loss": 0.04556343443691731,
|
| 333 |
+
"val_loss": 0.0053521015548280305,
|
| 334 |
+
"val_auc_pr": 0.3828373015873016,
|
| 335 |
+
"val_f1": 0.421052627700831,
|
| 336 |
+
"val_mae_log": 0.11446798592805862
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"epoch": 38,
|
| 340 |
+
"train_loss": 0.04497031863476779,
|
| 341 |
+
"val_loss": 0.005016647595246988,
|
| 342 |
+
"val_auc_pr": 0.38186813186813184,
|
| 343 |
+
"val_f1": 0.3809523773242631,
|
| 344 |
+
"val_mae_log": 0.11497646570205688
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"epoch": 39,
|
| 348 |
+
"train_loss": 0.04312905277553442,
|
| 349 |
+
"val_loss": 0.003749881671475513,
|
| 350 |
+
"val_auc_pr": 0.3595238095238095,
|
| 351 |
+
"val_f1": 0.3809523773242631,
|
| 352 |
+
"val_mae_log": 0.05548140034079552
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"epoch": 40,
|
| 356 |
+
"train_loss": 0.04352163130769859,
|
| 357 |
+
"val_loss": 0.005372332009885993,
|
| 358 |
+
"val_auc_pr": 0.3503288825869471,
|
| 359 |
+
"val_f1": 0.34782608355387534,
|
| 360 |
+
"val_mae_log": 0.08230870962142944
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 41,
|
| 364 |
+
"train_loss": 0.043740846146200156,
|
| 365 |
+
"val_loss": 0.0039979582319834405,
|
| 366 |
+
"val_auc_pr": 0.41458333333333336,
|
| 367 |
+
"val_f1": 0.3999999962500001,
|
| 368 |
+
"val_mae_log": 0.08734633028507233
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"epoch": 42,
|
| 372 |
+
"train_loss": 0.04409235781310378,
|
| 373 |
+
"val_loss": 0.005109895303446267,
|
| 374 |
+
"val_auc_pr": 0.2524756335282651,
|
| 375 |
+
"val_f1": 0.33333333003472226,
|
| 376 |
+
"val_mae_log": 0.07870446890592575
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"epoch": 43,
|
| 380 |
+
"train_loss": 0.043179894389735685,
|
| 381 |
+
"val_loss": 0.005041864268215639,
|
| 382 |
+
"val_auc_pr": 0.26508912655971484,
|
| 383 |
+
"val_f1": 0.36363636012396694,
|
| 384 |
+
"val_mae_log": 0.07578516006469727
|
| 385 |
+
},
|
| 386 |
+
{
|
| 387 |
+
"epoch": 44,
|
| 388 |
+
"train_loss": 0.04234155755792115,
|
| 389 |
+
"val_loss": 0.0038543779269925187,
|
| 390 |
+
"val_auc_pr": 0.3427519893899204,
|
| 391 |
+
"val_f1": 0.33333333055555564,
|
| 392 |
+
"val_mae_log": 0.06378159672021866
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"epoch": 45,
|
| 396 |
+
"train_loss": 0.043199574021068776,
|
| 397 |
+
"val_loss": 0.00448337330349854,
|
| 398 |
+
"val_auc_pr": 0.38693977591036416,
|
| 399 |
+
"val_f1": 0.36363636012396694,
|
| 400 |
+
"val_mae_log": 0.08112290501594543
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"epoch": 46,
|
| 404 |
+
"train_loss": 0.04324697579282361,
|
| 405 |
+
"val_loss": 0.004593804511906845,
|
| 406 |
+
"val_auc_pr": 0.3657142857142857,
|
| 407 |
+
"val_f1": 0.3809523773242631,
|
| 408 |
+
"val_mae_log": 0.12126877903938293
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"epoch": 47,
|
| 412 |
+
"train_loss": 0.042983541144309814,
|
| 413 |
+
"val_loss": 0.0034202520120223717,
|
| 414 |
+
"val_auc_pr": 0.36703703703703705,
|
| 415 |
+
"val_f1": 0.3809523773242631,
|
| 416 |
+
"val_mae_log": 0.05318637564778328
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"epoch": 48,
|
| 420 |
+
"train_loss": 0.04088504479543583,
|
| 421 |
+
"val_loss": 0.0037384599480511887,
|
| 422 |
+
"val_auc_pr": 0.35812684047978166,
|
| 423 |
+
"val_f1": 0.38461538150887575,
|
| 424 |
+
"val_mae_log": 0.0607416033744812
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"epoch": 49,
|
| 428 |
+
"train_loss": 0.0411647165143812,
|
| 429 |
+
"val_loss": 0.0038923417118244936,
|
| 430 |
+
"val_auc_pr": 0.37444444444444447,
|
| 431 |
+
"val_f1": 0.3809523773242631,
|
| 432 |
+
"val_mae_log": 0.07454186677932739
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"epoch": 50,
|
| 436 |
+
"train_loss": 0.04235347539589212,
|
| 437 |
+
"val_loss": 0.0035431724141484927,
|
| 438 |
+
"val_auc_pr": 0.3718181818181818,
|
| 439 |
+
"val_f1": 0.3809523773242631,
|
| 440 |
+
"val_mae_log": 0.05186235159635544
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"epoch": 51,
|
| 444 |
+
"train_loss": 0.03975096909782371,
|
| 445 |
+
"val_loss": 0.003855357279202768,
|
| 446 |
+
"val_auc_pr": 0.37,
|
| 447 |
+
"val_f1": 0.3809523773242631,
|
| 448 |
+
"val_mae_log": 0.08433445543050766
|
| 449 |
+
},
|
| 450 |
+
{
|
| 451 |
+
"epoch": 52,
|
| 452 |
+
"train_loss": 0.040304526777283564,
|
| 453 |
+
"val_loss": 0.003954493274380054,
|
| 454 |
+
"val_auc_pr": 0.36705882352941177,
|
| 455 |
+
"val_f1": 0.36363636012396694,
|
| 456 |
+
"val_mae_log": 0.0650041252374649
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"epoch": 53,
|
| 460 |
+
"train_loss": 0.041316902365636184,
|
| 461 |
+
"val_loss": 0.0044658422370308214,
|
| 462 |
+
"val_auc_pr": 0.37444444444444447,
|
| 463 |
+
"val_f1": 0.39999999680000003,
|
| 464 |
+
"val_mae_log": 0.08514165133237839
|
| 465 |
+
},
|
| 466 |
+
{
|
| 467 |
+
"epoch": 54,
|
| 468 |
+
"train_loss": 0.041085500773545856,
|
| 469 |
+
"val_loss": 0.003584100299381784,
|
| 470 |
+
"val_auc_pr": 0.36991596638655466,
|
| 471 |
+
"val_f1": 0.36363636012396694,
|
| 472 |
+
"val_mae_log": 0.04943912476301193
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"epoch": 55,
|
| 476 |
+
"train_loss": 0.04048956327543066,
|
| 477 |
+
"val_loss": 0.003669723236401166,
|
| 478 |
+
"val_auc_pr": 0.366961926961927,
|
| 479 |
+
"val_f1": 0.34782608355387534,
|
| 480 |
+
"val_mae_log": 0.0743192732334137
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"epoch": 56,
|
| 484 |
+
"train_loss": 0.04016674624101536,
|
| 485 |
+
"val_loss": 0.004304527521266469,
|
| 486 |
+
"val_auc_pr": 0.3745588235294118,
|
| 487 |
+
"val_f1": 0.39999999680000003,
|
| 488 |
+
"val_mae_log": 0.08440288156270981
|
| 489 |
+
}
|
| 490 |
+
],
|
| 491 |
+
"conformal": {
|
| 492 |
+
"alpha_0.01": {
|
| 493 |
+
"conformal_metrics": {
|
| 494 |
+
"alpha": 0.01,
|
| 495 |
+
"target_coverage": 0.99,
|
| 496 |
+
"marginal_coverage": 0.9700046146746655,
|
| 497 |
+
"coverage_guarantee_met": false,
|
| 498 |
+
"avg_set_size": 2.1033687125057683,
|
| 499 |
+
"efficiency": 0.4741578218735579,
|
| 500 |
+
"positive_coverage": 0.136986301369863,
|
| 501 |
+
"negative_coverage": 0.9990448901623686,
|
| 502 |
+
"set_size_distribution": {
|
| 503 |
+
"2": 1948,
|
| 504 |
+
"3": 214,
|
| 505 |
+
"4": 5
|
| 506 |
+
},
|
| 507 |
+
"n_test": 2167,
|
| 508 |
+
"mean_interval_width": 0.35249775648117065,
|
| 509 |
+
"median_interval_width": 0.3299492597579956
|
| 510 |
+
},
|
| 511 |
+
"conformal_state": {
|
| 512 |
+
"is_calibrated": true,
|
| 513 |
+
"alpha": 0.01,
|
| 514 |
+
"q_hat": 0.31530878875241947,
|
| 515 |
+
"q_residual": 0.31530878875241947,
|
| 516 |
+
"n_cal": 527,
|
| 517 |
+
"tiers": {
|
| 518 |
+
"LOW": [
|
| 519 |
+
0.0,
|
| 520 |
+
0.1
|
| 521 |
+
],
|
| 522 |
+
"MODERATE": [
|
| 523 |
+
0.1,
|
| 524 |
+
0.4
|
| 525 |
+
],
|
| 526 |
+
"HIGH": [
|
| 527 |
+
0.4,
|
| 528 |
+
0.7
|
| 529 |
+
],
|
| 530 |
+
"CRITICAL": [
|
| 531 |
+
0.7,
|
| 532 |
+
1.0
|
| 533 |
+
]
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
},
|
| 537 |
+
"alpha_0.05": {
|
| 538 |
+
"conformal_metrics": {
|
| 539 |
+
"alpha": 0.05,
|
| 540 |
+
"target_coverage": 0.95,
|
| 541 |
+
"marginal_coverage": 0.9487771112136595,
|
| 542 |
+
"coverage_guarantee_met": true,
|
| 543 |
+
"avg_set_size": 1.9856945085371482,
|
| 544 |
+
"efficiency": 0.503576372865713,
|
| 545 |
+
"positive_coverage": 0.0,
|
| 546 |
+
"negative_coverage": 0.9818529130850048,
|
| 547 |
+
"set_size_distribution": {
|
| 548 |
+
"1": 31,
|
| 549 |
+
"2": 2136
|
| 550 |
+
},
|
| 551 |
+
"n_test": 2167,
|
| 552 |
+
"mean_interval_width": 0.14139389991760254,
|
| 553 |
+
"median_interval_width": 0.1266784965991974
|
| 554 |
+
},
|
| 555 |
+
"conformal_state": {
|
| 556 |
+
"is_calibrated": true,
|
| 557 |
+
"alpha": 0.05,
|
| 558 |
+
"q_hat": 0.1120380280677236,
|
| 559 |
+
"q_residual": 0.1120380280677236,
|
| 560 |
+
"n_cal": 527,
|
| 561 |
+
"tiers": {
|
| 562 |
+
"LOW": [
|
| 563 |
+
0.0,
|
| 564 |
+
0.1
|
| 565 |
+
],
|
| 566 |
+
"MODERATE": [
|
| 567 |
+
0.1,
|
| 568 |
+
0.4
|
| 569 |
+
],
|
| 570 |
+
"HIGH": [
|
| 571 |
+
0.4,
|
| 572 |
+
0.7
|
| 573 |
+
],
|
| 574 |
+
"CRITICAL": [
|
| 575 |
+
0.7,
|
| 576 |
+
1.0
|
| 577 |
+
]
|
| 578 |
+
}
|
| 579 |
+
}
|
| 580 |
+
},
|
| 581 |
+
"alpha_0.1": {
|
| 582 |
+
"conformal_metrics": {
|
| 583 |
+
"alpha": 0.1,
|
| 584 |
+
"target_coverage": 0.9,
|
| 585 |
+
"marginal_coverage": 0.9284725426857406,
|
| 586 |
+
"coverage_guarantee_met": true,
|
| 587 |
+
"avg_set_size": 1.103830179972312,
|
| 588 |
+
"efficiency": 0.724042455006922,
|
| 589 |
+
"positive_coverage": 0.0,
|
| 590 |
+
"negative_coverage": 0.9608404966571156,
|
| 591 |
+
"set_size_distribution": {
|
| 592 |
+
"1": 1942,
|
| 593 |
+
"2": 225
|
| 594 |
+
},
|
| 595 |
+
"n_test": 2167,
|
| 596 |
+
"mean_interval_width": 0.060726769268512726,
|
| 597 |
+
"median_interval_width": 0.05510023236274719
|
| 598 |
+
},
|
| 599 |
+
"conformal_state": {
|
| 600 |
+
"is_calibrated": true,
|
| 601 |
+
"alpha": 0.1,
|
| 602 |
+
"q_hat": 0.04045976169647709,
|
| 603 |
+
"q_residual": 0.04045976169647709,
|
| 604 |
+
"n_cal": 527,
|
| 605 |
+
"tiers": {
|
| 606 |
+
"LOW": [
|
| 607 |
+
0.0,
|
| 608 |
+
0.1
|
| 609 |
+
],
|
| 610 |
+
"MODERATE": [
|
| 611 |
+
0.1,
|
| 612 |
+
0.4
|
| 613 |
+
],
|
| 614 |
+
"HIGH": [
|
| 615 |
+
0.4,
|
| 616 |
+
0.7
|
| 617 |
+
],
|
| 618 |
+
"CRITICAL": [
|
| 619 |
+
0.7,
|
| 620 |
+
1.0
|
| 621 |
+
]
|
| 622 |
+
}
|
| 623 |
+
}
|
| 624 |
+
},
|
| 625 |
+
"alpha_0.2": {
|
| 626 |
+
"conformal_metrics": {
|
| 627 |
+
"alpha": 0.2,
|
| 628 |
+
"target_coverage": 0.8,
|
| 629 |
+
"marginal_coverage": 0.9220119981541302,
|
| 630 |
+
"coverage_guarantee_met": true,
|
| 631 |
+
"avg_set_size": 1.054453161052146,
|
| 632 |
+
"efficiency": 0.7363867097369635,
|
| 633 |
+
"positive_coverage": 0.0,
|
| 634 |
+
"negative_coverage": 0.9541547277936963,
|
| 635 |
+
"set_size_distribution": {
|
| 636 |
+
"1": 2049,
|
| 637 |
+
"2": 118
|
| 638 |
+
},
|
| 639 |
+
"n_test": 2167,
|
| 640 |
+
"mean_interval_width": 0.04071307182312012,
|
| 641 |
+
"median_interval_width": 0.039181869477033615
|
| 642 |
+
},
|
| 643 |
+
"conformal_state": {
|
| 644 |
+
"is_calibrated": true,
|
| 645 |
+
"alpha": 0.2,
|
| 646 |
+
"q_hat": 0.024541400479014954,
|
| 647 |
+
"q_residual": 0.024541400479014954,
|
| 648 |
+
"n_cal": 527,
|
| 649 |
+
"tiers": {
|
| 650 |
+
"LOW": [
|
| 651 |
+
0.0,
|
| 652 |
+
0.1
|
| 653 |
+
],
|
| 654 |
+
"MODERATE": [
|
| 655 |
+
0.1,
|
| 656 |
+
0.4
|
| 657 |
+
],
|
| 658 |
+
"HIGH": [
|
| 659 |
+
0.4,
|
| 660 |
+
0.7
|
| 661 |
+
],
|
| 662 |
+
"CRITICAL": [
|
| 663 |
+
0.7,
|
| 664 |
+
1.0
|
| 665 |
+
]
|
| 666 |
+
}
|
| 667 |
+
}
|
| 668 |
+
}
|
| 669 |
+
}
|
| 670 |
+
}
|
results/model_comparison.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"model": "Orbital Shell Baseline",
|
| 4 |
+
"auc_pr": 0.061184346220415166,
|
| 5 |
+
"auc_roc": 0.6374507725922728,
|
| 6 |
+
"f1_at_50": 0.0,
|
| 7 |
+
"n_positive": 73,
|
| 8 |
+
"n_total": 2167,
|
| 9 |
+
"pos_rate": 0.03368712505768343,
|
| 10 |
+
"f1": 0.13223140017211957,
|
| 11 |
+
"optimal_threshold": 0.03237410071942446,
|
| 12 |
+
"threshold": 0.03237410071942446,
|
| 13 |
+
"recall_at_prec_30": 0.0,
|
| 14 |
+
"recall_at_prec_50": 0.0,
|
| 15 |
+
"recall_at_prec_70": 0.0,
|
| 16 |
+
"mae_log": 0.9927019602313063,
|
| 17 |
+
"rmse_log": 1.2867684153860748,
|
| 18 |
+
"mae_km": 10600.126897201788,
|
| 19 |
+
"median_abs_error_km": 7222.8428976622645
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"model": "XGBoost (Engineered Features)",
|
| 23 |
+
"auc_pr": 0.9884220304219559,
|
| 24 |
+
"auc_roc": 0.9995944054114168,
|
| 25 |
+
"f1_at_50": 0.9411764705882353,
|
| 26 |
+
"n_positive": 73,
|
| 27 |
+
"n_total": 2167,
|
| 28 |
+
"pos_rate": 0.03368712505768343,
|
| 29 |
+
"f1": 0.9473684160604224,
|
| 30 |
+
"optimal_threshold": 0.5539590716362,
|
| 31 |
+
"threshold": 0.5539590716362,
|
| 32 |
+
"recall_at_prec_30": 1.0,
|
| 33 |
+
"recall_at_prec_50": 1.0,
|
| 34 |
+
"recall_at_prec_70": 1.0,
|
| 35 |
+
"mae_log": 0.011742588180292227,
|
| 36 |
+
"rmse_log": 0.03972278871639667,
|
| 37 |
+
"mae_km": 80.85688587394668,
|
| 38 |
+
"median_abs_error_km": 42.99218749998545
|
| 39 |
+
}
|
| 40 |
+
]
|
results/staleness_experiment.json
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cutoffs": [
|
| 3 |
+
2.0,
|
| 4 |
+
2.5,
|
| 5 |
+
3.0,
|
| 6 |
+
3.5,
|
| 7 |
+
4.0,
|
| 8 |
+
5.0,
|
| 9 |
+
6.0
|
| 10 |
+
],
|
| 11 |
+
"n_test_events": 2167,
|
| 12 |
+
"n_positive": 73,
|
| 13 |
+
"baseline": [
|
| 14 |
+
{
|
| 15 |
+
"auc_pr": 0.061184346220415166,
|
| 16 |
+
"auc_roc": 0.6374507725922728,
|
| 17 |
+
"f1_at_50": 0.0,
|
| 18 |
+
"n_positive": 73,
|
| 19 |
+
"n_total": 2167,
|
| 20 |
+
"pos_rate": 0.03368712505768343,
|
| 21 |
+
"f1": 0.13223140017211957,
|
| 22 |
+
"optimal_threshold": 0.03237410071942446,
|
| 23 |
+
"threshold": 0.03237410071942446,
|
| 24 |
+
"recall_at_prec_30": 0.0,
|
| 25 |
+
"recall_at_prec_50": 0.0,
|
| 26 |
+
"recall_at_prec_70": 0.0,
|
| 27 |
+
"cutoff": 2.0,
|
| 28 |
+
"n_events": 2167
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"auc_pr": 0.061184346220415166,
|
| 32 |
+
"auc_roc": 0.6374507725922728,
|
| 33 |
+
"f1_at_50": 0.0,
|
| 34 |
+
"n_positive": 73,
|
| 35 |
+
"n_total": 2167,
|
| 36 |
+
"pos_rate": 0.03368712505768343,
|
| 37 |
+
"f1": 0.13223140017211957,
|
| 38 |
+
"optimal_threshold": 0.03237410071942446,
|
| 39 |
+
"threshold": 0.03237410071942446,
|
| 40 |
+
"recall_at_prec_30": 0.0,
|
| 41 |
+
"recall_at_prec_50": 0.0,
|
| 42 |
+
"recall_at_prec_70": 0.0,
|
| 43 |
+
"cutoff": 2.5,
|
| 44 |
+
"n_events": 2167
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"auc_pr": 0.061184346220415166,
|
| 48 |
+
"auc_roc": 0.6374507725922728,
|
| 49 |
+
"f1_at_50": 0.0,
|
| 50 |
+
"n_positive": 73,
|
| 51 |
+
"n_total": 2167,
|
| 52 |
+
"pos_rate": 0.03368712505768343,
|
| 53 |
+
"f1": 0.13223140017211957,
|
| 54 |
+
"optimal_threshold": 0.03237410071942446,
|
| 55 |
+
"threshold": 0.03237410071942446,
|
| 56 |
+
"recall_at_prec_30": 0.0,
|
| 57 |
+
"recall_at_prec_50": 0.0,
|
| 58 |
+
"recall_at_prec_70": 0.0,
|
| 59 |
+
"cutoff": 3.0,
|
| 60 |
+
"n_events": 2167
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"auc_pr": 0.061184346220415166,
|
| 64 |
+
"auc_roc": 0.6374507725922728,
|
| 65 |
+
"f1_at_50": 0.0,
|
| 66 |
+
"n_positive": 73,
|
| 67 |
+
"n_total": 2167,
|
| 68 |
+
"pos_rate": 0.03368712505768343,
|
| 69 |
+
"f1": 0.13223140017211957,
|
| 70 |
+
"optimal_threshold": 0.03237410071942446,
|
| 71 |
+
"threshold": 0.03237410071942446,
|
| 72 |
+
"recall_at_prec_30": 0.0,
|
| 73 |
+
"recall_at_prec_50": 0.0,
|
| 74 |
+
"recall_at_prec_70": 0.0,
|
| 75 |
+
"cutoff": 3.5,
|
| 76 |
+
"n_events": 2167
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"auc_pr": 0.061184346220415166,
|
| 80 |
+
"auc_roc": 0.6374507725922728,
|
| 81 |
+
"f1_at_50": 0.0,
|
| 82 |
+
"n_positive": 73,
|
| 83 |
+
"n_total": 2167,
|
| 84 |
+
"pos_rate": 0.03368712505768343,
|
| 85 |
+
"f1": 0.13223140017211957,
|
| 86 |
+
"optimal_threshold": 0.03237410071942446,
|
| 87 |
+
"threshold": 0.03237410071942446,
|
| 88 |
+
"recall_at_prec_30": 0.0,
|
| 89 |
+
"recall_at_prec_50": 0.0,
|
| 90 |
+
"recall_at_prec_70": 0.0,
|
| 91 |
+
"cutoff": 4.0,
|
| 92 |
+
"n_events": 2167
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"auc_pr": 0.061184346220415166,
|
| 96 |
+
"auc_roc": 0.6374507725922728,
|
| 97 |
+
"f1_at_50": 0.0,
|
| 98 |
+
"n_positive": 73,
|
| 99 |
+
"n_total": 2167,
|
| 100 |
+
"pos_rate": 0.03368712505768343,
|
| 101 |
+
"f1": 0.13223140017211957,
|
| 102 |
+
"optimal_threshold": 0.03237410071942446,
|
| 103 |
+
"threshold": 0.03237410071942446,
|
| 104 |
+
"recall_at_prec_30": 0.0,
|
| 105 |
+
"recall_at_prec_50": 0.0,
|
| 106 |
+
"recall_at_prec_70": 0.0,
|
| 107 |
+
"cutoff": 5.0,
|
| 108 |
+
"n_events": 2167
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"auc_pr": 0.061184346220415166,
|
| 112 |
+
"auc_roc": 0.6374507725922728,
|
| 113 |
+
"f1_at_50": 0.0,
|
| 114 |
+
"n_positive": 73,
|
| 115 |
+
"n_total": 2167,
|
| 116 |
+
"pos_rate": 0.03368712505768343,
|
| 117 |
+
"f1": 0.13223140017211957,
|
| 118 |
+
"optimal_threshold": 0.03237410071942446,
|
| 119 |
+
"threshold": 0.03237410071942446,
|
| 120 |
+
"recall_at_prec_30": 0.0,
|
| 121 |
+
"recall_at_prec_50": 0.0,
|
| 122 |
+
"recall_at_prec_70": 0.0,
|
| 123 |
+
"cutoff": 6.0,
|
| 124 |
+
"n_events": 2167
|
| 125 |
+
}
|
| 126 |
+
],
|
| 127 |
+
"xgboost": [
|
| 128 |
+
{
|
| 129 |
+
"auc_pr": 0.9883137899600032,
|
| 130 |
+
"auc_roc": 0.9995878635632139,
|
| 131 |
+
"f1_at_50": 0.935064935064935,
|
| 132 |
+
"n_positive": 73,
|
| 133 |
+
"n_total": 2167,
|
| 134 |
+
"pos_rate": 0.03368712505768343,
|
| 135 |
+
"f1": 0.9411764655987015,
|
| 136 |
+
"optimal_threshold": 0.5284891724586487,
|
| 137 |
+
"threshold": 0.5284891724586487,
|
| 138 |
+
"recall_at_prec_30": 1.0,
|
| 139 |
+
"recall_at_prec_50": 1.0,
|
| 140 |
+
"recall_at_prec_70": 1.0,
|
| 141 |
+
"cutoff": 2.0,
|
| 142 |
+
"n_events": 2167
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"auc_pr": 0.9123203140942627,
|
| 146 |
+
"auc_roc": 0.9903418565869928,
|
| 147 |
+
"f1_at_50": 0.8421052631578947,
|
| 148 |
+
"n_positive": 70,
|
| 149 |
+
"n_total": 2126,
|
| 150 |
+
"pos_rate": 0.03292568203198495,
|
| 151 |
+
"f1": 0.8467153234695509,
|
| 152 |
+
"optimal_threshold": 0.9780168533325195,
|
| 153 |
+
"threshold": 0.9780168533325195,
|
| 154 |
+
"recall_at_prec_30": 0.9857142857142858,
|
| 155 |
+
"recall_at_prec_50": 0.9714285714285714,
|
| 156 |
+
"recall_at_prec_70": 0.9285714285714286,
|
| 157 |
+
"cutoff": 2.5,
|
| 158 |
+
"n_events": 2126
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"auc_pr": 0.7112636105696798,
|
| 162 |
+
"auc_roc": 0.9702624390685601,
|
| 163 |
+
"f1_at_50": 0.7012987012987013,
|
| 164 |
+
"n_positive": 67,
|
| 165 |
+
"n_total": 2045,
|
| 166 |
+
"pos_rate": 0.03276283618581907,
|
| 167 |
+
"f1": 0.722222217246335,
|
| 168 |
+
"optimal_threshold": 0.9061354398727417,
|
| 169 |
+
"threshold": 0.9061354398727417,
|
| 170 |
+
"recall_at_prec_30": 0.9104477611940298,
|
| 171 |
+
"recall_at_prec_50": 0.8507462686567164,
|
| 172 |
+
"recall_at_prec_70": 0.7164179104477612,
|
| 173 |
+
"cutoff": 3.0,
|
| 174 |
+
"n_events": 2045
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"auc_pr": 0.7224173760553306,
|
| 178 |
+
"auc_roc": 0.9779084384250436,
|
| 179 |
+
"f1_at_50": 0.6666666666666666,
|
| 180 |
+
"n_positive": 65,
|
| 181 |
+
"n_total": 1962,
|
| 182 |
+
"pos_rate": 0.033129459734964326,
|
| 183 |
+
"f1": 0.6802721039104078,
|
| 184 |
+
"optimal_threshold": 0.8590014576911926,
|
| 185 |
+
"threshold": 0.8590014576911926,
|
| 186 |
+
"recall_at_prec_30": 0.9384615384615385,
|
| 187 |
+
"recall_at_prec_50": 0.8615384615384616,
|
| 188 |
+
"recall_at_prec_70": 0.6153846153846154,
|
| 189 |
+
"cutoff": 3.5,
|
| 190 |
+
"n_events": 1962
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"auc_pr": 0.6392429519999454,
|
| 194 |
+
"auc_roc": 0.9669743064869061,
|
| 195 |
+
"f1_at_50": 0.5921052631578947,
|
| 196 |
+
"n_positive": 62,
|
| 197 |
+
"n_total": 1890,
|
| 198 |
+
"pos_rate": 0.0328042328042328,
|
| 199 |
+
"f1": 0.6370370320702333,
|
| 200 |
+
"optimal_threshold": 0.8714247941970825,
|
| 201 |
+
"threshold": 0.8714247941970825,
|
| 202 |
+
"recall_at_prec_30": 0.8870967741935484,
|
| 203 |
+
"recall_at_prec_50": 0.8064516129032258,
|
| 204 |
+
"recall_at_prec_70": 0.41935483870967744,
|
| 205 |
+
"cutoff": 4.0,
|
| 206 |
+
"n_events": 1890
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"auc_pr": 0.42295193898950256,
|
| 210 |
+
"auc_roc": 0.9482351744481741,
|
| 211 |
+
"f1_at_50": 0.5419354838709678,
|
| 212 |
+
"n_positive": 58,
|
| 213 |
+
"n_total": 1753,
|
| 214 |
+
"pos_rate": 0.03308613804905876,
|
| 215 |
+
"f1": 0.5454545404630832,
|
| 216 |
+
"optimal_threshold": 0.9965507984161377,
|
| 217 |
+
"threshold": 0.9965507984161377,
|
| 218 |
+
"recall_at_prec_30": 0.7931034482758621,
|
| 219 |
+
"recall_at_prec_50": 0.5689655172413793,
|
| 220 |
+
"recall_at_prec_70": 0.0,
|
| 221 |
+
"cutoff": 5.0,
|
| 222 |
+
"n_events": 1753
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"auc_pr": 0.3219032626600778,
|
| 226 |
+
"auc_roc": 0.9162752848174842,
|
| 227 |
+
"f1_at_50": 0.4027777777777778,
|
| 228 |
+
"n_positive": 55,
|
| 229 |
+
"n_total": 1619,
|
| 230 |
+
"pos_rate": 0.033971587399629403,
|
| 231 |
+
"f1": 0.42592592092764064,
|
| 232 |
+
"optimal_threshold": 0.9984425902366638,
|
| 233 |
+
"threshold": 0.9984425902366638,
|
| 234 |
+
"recall_at_prec_30": 0.5818181818181818,
|
| 235 |
+
"recall_at_prec_50": 0.12727272727272726,
|
| 236 |
+
"recall_at_prec_70": 0.01818181818181818,
|
| 237 |
+
"cutoff": 6.0,
|
| 238 |
+
"n_events": 1619
|
| 239 |
+
}
|
| 240 |
+
],
|
| 241 |
+
"pitft": [
|
| 242 |
+
{
|
| 243 |
+
"auc_pr": 0.5108315323239697,
|
| 244 |
+
"auc_roc": 0.9467689811725608,
|
| 245 |
+
"f1_at_50": 0.0,
|
| 246 |
+
"n_positive": 73,
|
| 247 |
+
"n_total": 2167,
|
| 248 |
+
"pos_rate": 0.03368712505768343,
|
| 249 |
+
"f1": 0.5325443737908337,
|
| 250 |
+
"optimal_threshold": 0.18103967607021332,
|
| 251 |
+
"threshold": 0.18103967607021332,
|
| 252 |
+
"recall_at_prec_30": 0.7808219178082192,
|
| 253 |
+
"recall_at_prec_50": 0.5068493150684932,
|
| 254 |
+
"recall_at_prec_70": 0.2876712328767123,
|
| 255 |
+
"cutoff": 2.0,
|
| 256 |
+
"n_events": 2167
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"auc_pr": 0.40929547300496166,
|
| 260 |
+
"auc_roc": 0.9342620900500278,
|
| 261 |
+
"f1_at_50": 0.028169014084507043,
|
| 262 |
+
"n_positive": 70,
|
| 263 |
+
"n_total": 2126,
|
| 264 |
+
"pos_rate": 0.03292568203198495,
|
| 265 |
+
"f1": 0.45121950730220106,
|
| 266 |
+
"optimal_threshold": 0.18565748631954193,
|
| 267 |
+
"threshold": 0.18565748631954193,
|
| 268 |
+
"recall_at_prec_30": 0.6571428571428571,
|
| 269 |
+
"recall_at_prec_50": 0.35714285714285715,
|
| 270 |
+
"recall_at_prec_70": 0.2,
|
| 271 |
+
"cutoff": 2.5,
|
| 272 |
+
"n_events": 2126
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"auc_pr": 0.3126159912723518,
|
| 276 |
+
"auc_roc": 0.9086669785551514,
|
| 277 |
+
"f1_at_50": 0.056338028169014086,
|
| 278 |
+
"n_positive": 67,
|
| 279 |
+
"n_total": 2045,
|
| 280 |
+
"pos_rate": 0.03276283618581907,
|
| 281 |
+
"f1": 0.3968253918455531,
|
| 282 |
+
"optimal_threshold": 0.2572215497493744,
|
| 283 |
+
"threshold": 0.2572215497493744,
|
| 284 |
+
"recall_at_prec_30": 0.4626865671641791,
|
| 285 |
+
"recall_at_prec_50": 0.208955223880597,
|
| 286 |
+
"recall_at_prec_70": 0.0,
|
| 287 |
+
"cutoff": 3.0,
|
| 288 |
+
"n_events": 2045
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"auc_pr": 0.32548992974654617,
|
| 292 |
+
"auc_roc": 0.9031263939013017,
|
| 293 |
+
"f1_at_50": 0.058823529411764705,
|
| 294 |
+
"n_positive": 65,
|
| 295 |
+
"n_total": 1962,
|
| 296 |
+
"pos_rate": 0.033129459734964326,
|
| 297 |
+
"f1": 0.3716814110423683,
|
| 298 |
+
"optimal_threshold": 0.28492599725723267,
|
| 299 |
+
"threshold": 0.28492599725723267,
|
| 300 |
+
"recall_at_prec_30": 0.46153846153846156,
|
| 301 |
+
"recall_at_prec_50": 0.26153846153846155,
|
| 302 |
+
"recall_at_prec_70": 0.015384615384615385,
|
| 303 |
+
"cutoff": 3.5,
|
| 304 |
+
"n_events": 1962
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"auc_pr": 0.286925285041537,
|
| 308 |
+
"auc_roc": 0.892249594127197,
|
| 309 |
+
"f1_at_50": 0.0,
|
| 310 |
+
"n_positive": 62,
|
| 311 |
+
"n_total": 1890,
|
| 312 |
+
"pos_rate": 0.0328042328042328,
|
| 313 |
+
"f1": 0.3736263691341626,
|
| 314 |
+
"optimal_threshold": 0.16788320243358612,
|
| 315 |
+
"threshold": 0.16788320243358612,
|
| 316 |
+
"recall_at_prec_30": 0.45161290322580644,
|
| 317 |
+
"recall_at_prec_50": 0.22580645161290322,
|
| 318 |
+
"recall_at_prec_70": 0.0,
|
| 319 |
+
"cutoff": 4.0,
|
| 320 |
+
"n_events": 1890
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"auc_pr": 0.23877494536053875,
|
| 324 |
+
"auc_roc": 0.867622825755264,
|
| 325 |
+
"f1_at_50": 0.0625,
|
| 326 |
+
"n_positive": 58,
|
| 327 |
+
"n_total": 1753,
|
| 328 |
+
"pos_rate": 0.03308613804905876,
|
| 329 |
+
"f1": 0.33082706275086216,
|
| 330 |
+
"optimal_threshold": 0.21164827048778534,
|
| 331 |
+
"threshold": 0.21164827048778534,
|
| 332 |
+
"recall_at_prec_30": 0.3103448275862069,
|
| 333 |
+
"recall_at_prec_50": 0.1896551724137931,
|
| 334 |
+
"recall_at_prec_70": 0.0,
|
| 335 |
+
"cutoff": 5.0,
|
| 336 |
+
"n_events": 1753
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"auc_pr": 0.1838323482889146,
|
| 340 |
+
"auc_roc": 0.8097419204836084,
|
| 341 |
+
"f1_at_50": 0.06666666666666667,
|
| 342 |
+
"n_positive": 55,
|
| 343 |
+
"n_total": 1619,
|
| 344 |
+
"pos_rate": 0.033971587399629403,
|
| 345 |
+
"f1": 0.2741935434508325,
|
| 346 |
+
"optimal_threshold": 0.21228547394275665,
|
| 347 |
+
"threshold": 0.21228547394275665,
|
| 348 |
+
"recall_at_prec_30": 0.18181818181818182,
|
| 349 |
+
"recall_at_prec_50": 0.07272727272727272,
|
| 350 |
+
"recall_at_prec_70": 0.0,
|
| 351 |
+
"cutoff": 6.0,
|
| 352 |
+
"n_events": 1619
|
| 353 |
+
}
|
| 354 |
+
]
|
| 355 |
+
}
|
src/__init__.py
ADDED
|
File without changes
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/augment.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Data augmentation for the conjunction prediction dataset.
|
| 3 |
+
|
| 4 |
+
The fundamental problem: only 67 high-risk events out of 13,154 in training (0.5%).
|
| 5 |
+
This module provides two augmentation strategies:
|
| 6 |
+
|
| 7 |
+
1. SPACE-TRACK INTEGRATION: Merge real high-risk CDMs from Space-Track's cdm_public
|
| 8 |
+
feed. These have fewer features (16 vs 103) but provide real positive examples.
|
| 9 |
+
|
| 10 |
+
2. TIME-SERIES AUGMENTATION: Create synthetic variants of existing high-risk events
|
| 11 |
+
by applying realistic perturbations:
|
| 12 |
+
- Gaussian noise on covariance/position/velocity features
|
| 13 |
+
- Temporal jittering (shift CDM creation times slightly)
|
| 14 |
+
- Feature dropout (randomly zero out some features, simulating missing data)
|
| 15 |
+
- Sequence truncation (remove early CDMs, simulating late detection)
|
| 16 |
+
|
| 17 |
+
Both strategies are physics-aware: they don't generate impossible configurations
|
| 18 |
+
(e.g., negative miss distances or covariance values).
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def augment_event_noise(
|
| 27 |
+
event_df: pd.DataFrame,
|
| 28 |
+
noise_scale: float = 0.05,
|
| 29 |
+
n_augments: int = 5,
|
| 30 |
+
rng: np.random.Generator = None,
|
| 31 |
+
) -> list[pd.DataFrame]:
|
| 32 |
+
"""
|
| 33 |
+
Create n_augments noisy variants of a single conjunction event.
|
| 34 |
+
|
| 35 |
+
Applies Gaussian noise to numeric features, scaled by each column's
|
| 36 |
+
standard deviation within the event. Preserves event_id structure and
|
| 37 |
+
ensures physical constraints (non-negative distances, etc.).
|
| 38 |
+
"""
|
| 39 |
+
if rng is None:
|
| 40 |
+
rng = np.random.default_rng(42)
|
| 41 |
+
|
| 42 |
+
# Identify numeric columns to perturb (exclude IDs and targets)
|
| 43 |
+
exclude = {"event_id", "time_to_tca", "risk", "mission_id", "source"}
|
| 44 |
+
numeric_cols = event_df.select_dtypes(include=[np.number]).columns
|
| 45 |
+
perturb_cols = [c for c in numeric_cols if c not in exclude]
|
| 46 |
+
|
| 47 |
+
augmented = []
|
| 48 |
+
for i in range(n_augments):
|
| 49 |
+
aug = event_df.copy()
|
| 50 |
+
|
| 51 |
+
for col in perturb_cols:
|
| 52 |
+
values = aug[col].values.astype(float)
|
| 53 |
+
col_std = np.std(values)
|
| 54 |
+
if col_std < 1e-10:
|
| 55 |
+
col_std = np.abs(np.mean(values)) * 0.01 + 1e-10
|
| 56 |
+
|
| 57 |
+
noise = rng.normal(0, noise_scale * col_std, size=len(values))
|
| 58 |
+
aug[col] = values + noise
|
| 59 |
+
|
| 60 |
+
# Physical constraints
|
| 61 |
+
if "miss_distance" in aug.columns:
|
| 62 |
+
aug["miss_distance"] = aug["miss_distance"].clip(lower=0)
|
| 63 |
+
if "relative_speed" in aug.columns:
|
| 64 |
+
aug["relative_speed"] = aug["relative_speed"].clip(lower=0)
|
| 65 |
+
|
| 66 |
+
# Ensure covariance sigma columns stay positive
|
| 67 |
+
sigma_cols = [c for c in perturb_cols if "sigma" in c.lower()]
|
| 68 |
+
for col in sigma_cols:
|
| 69 |
+
aug[col] = aug[col].clip(lower=0)
|
| 70 |
+
|
| 71 |
+
augmented.append(aug)
|
| 72 |
+
|
| 73 |
+
return augmented
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def augment_event_truncate(
|
| 77 |
+
event_df: pd.DataFrame,
|
| 78 |
+
min_keep: int = 3,
|
| 79 |
+
n_augments: int = 3,
|
| 80 |
+
rng: np.random.Generator = None,
|
| 81 |
+
) -> list[pd.DataFrame]:
|
| 82 |
+
"""
|
| 83 |
+
Create truncated variants by removing early CDMs.
|
| 84 |
+
|
| 85 |
+
Simulates late-detection scenarios where only the most recent CDMs
|
| 86 |
+
are available (closer to TCA).
|
| 87 |
+
"""
|
| 88 |
+
if rng is None:
|
| 89 |
+
rng = np.random.default_rng(42)
|
| 90 |
+
|
| 91 |
+
# Sort by time_to_tca descending (first CDM = furthest from TCA)
|
| 92 |
+
event_df = event_df.sort_values("time_to_tca", ascending=False)
|
| 93 |
+
n_cdms = len(event_df)
|
| 94 |
+
|
| 95 |
+
if n_cdms <= min_keep:
|
| 96 |
+
return []
|
| 97 |
+
|
| 98 |
+
augmented = []
|
| 99 |
+
for _ in range(n_augments):
|
| 100 |
+
# Keep between min_keep and n_cdms-1 CDMs (always keep the last few)
|
| 101 |
+
n_keep = rng.integers(min_keep, n_cdms)
|
| 102 |
+
aug = event_df.iloc[-n_keep:].copy()
|
| 103 |
+
augmented.append(aug)
|
| 104 |
+
|
| 105 |
+
return augmented
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def augment_positive_events(
    df: pd.DataFrame,
    target_ratio: float = 0.05,
    noise_scale: float = 0.05,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Augment the positive (high-risk) class until it makes up target_ratio.

    Positive events (final risk > -5) are sampled with replacement and
    perturbed via augment_event_noise (plus occasional truncation via
    augment_event_truncate).  Synthetic events receive fresh event_ids and
    source == "augmented".

    Args:
        df: full training DataFrame with event_id, risk columns
        target_ratio: desired fraction of high-risk events (default 5%)
        noise_scale: std dev of Gaussian noise as fraction of feature std
        seed: random seed

    Returns:
        Augmented DataFrame with new synthetic positive events appended
    """
    rng = np.random.default_rng(seed)

    # Find positive events (the label comes from each event's last CDM)
    event_risks = df.groupby("event_id")["risk"].last()
    pos_event_ids = event_risks[event_risks > -5].index.tolist()
    neg_event_ids = event_risks[event_risks <= -5].index.tolist()

    n_pos = len(pos_event_ids)
    n_neg = len(neg_event_ids)
    n_total = n_pos + n_neg

    if n_total == 0:
        # Empty input: nothing to do (previously divided by zero below).
        return df

    # How many positives do we need overall?  The negative count is fixed,
    # so solving p / (p + n_neg) = r gives p = r * n_neg / (1 - r).
    # (The previous formula used n_total instead of n_neg, which overshot
    # the requested ratio.)
    target_pos = int(np.ceil(target_ratio * n_neg / (1 - target_ratio)))
    n_needed = max(0, target_pos - n_pos)

    if n_needed == 0:
        print(f"Already at target ratio ({n_pos}/{n_total} = {n_pos/n_total:.1%})")
        return df

    if not pos_event_ids:
        # Nothing to clone from — previously this crashed in rng.choice([]).
        print("No positive events available to augment; returning unchanged")
        return df

    print(f"Augmenting: {n_pos} positive events → {n_pos + n_needed} "
          f"(target {target_ratio:.0%} of {n_total + n_needed})")

    # Generate augmented events, each with a fresh unique event_id
    max_event_id = df["event_id"].max()
    augmented_dfs = []
    generated = 0

    while generated < n_needed:
        # Pick a random positive event to augment (with replacement)
        src_event_id = rng.choice(pos_event_ids)
        src_event = df[df["event_id"] == src_event_id]

        # Apply noise augmentation
        aug_variants = augment_event_noise(
            src_event, noise_scale=noise_scale, n_augments=1, rng=rng
        )

        # Also try truncation sometimes (simulates late detection)
        if rng.random() < 0.3 and len(src_event) > 3:
            trunc_variants = augment_event_truncate(
                src_event, n_augments=1, rng=rng
            )
            aug_variants.extend(trunc_variants)

        for aug_df in aug_variants:
            if generated >= n_needed:
                break
            max_event_id += 1
            aug_df = aug_df.copy()
            aug_df["event_id"] = max_event_id
            aug_df["source"] = "augmented"
            augmented_dfs.append(aug_df)
            generated += 1

    if augmented_dfs:
        augmented = pd.concat(augmented_dfs, ignore_index=True)
        result = pd.concat([df, augmented], ignore_index=True)

        # Verify the achieved ratio
        event_risks = result.groupby("event_id")["risk"].last()
        new_pos = (event_risks > -5).sum()
        new_total = len(event_risks)
        print(f"Result: {new_pos} positive / {new_total} total "
              f"({new_pos/new_total:.1%})")
        return result

    return df
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def integrate_spacetrack_positives(
    kelvins_df: pd.DataFrame,
    spacetrack_path: Path,
) -> pd.DataFrame:
    """
    Fold Space-Track emergency CDMs into the training frame as extra positives.

    Space-Track's cdm_public feed carries only 16 of Kelvins' 103 features;
    the merge step zero-fills whatever is missing, and the model learns from
    whichever features are present.  Returns the input unchanged when no
    Space-Track file exists.
    """
    if not spacetrack_path.exists():
        print(f"No Space-Track data at {spacetrack_path}")
        return kelvins_df

    from src.data.merge_sources import (
        load_spacetrack_cdms, group_into_events, merge_datasets
    )

    raw_cdms = load_spacetrack_cdms(spacetrack_path)
    eventised = group_into_events(raw_cdms)
    return merge_datasets(kelvins_df, eventised)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def build_augmented_training_set(
    data_dir: Path,
    target_positive_ratio: float = 0.05,
    noise_scale: float = 0.05,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Build the full augmented training set from all available sources.

    Steps:
    1. Load ESA Kelvins train/test
    2. Merge Space-Track emergency CDMs into training set
    3. Apply time-series augmentation to positive events
    4. Return (augmented_train, original_test)

    Test set is NEVER augmented — it stays as Kelvins-only for fair evaluation.

    Args:
        data_dir: root data directory; expects a "cdm" subdirectory and
            optionally "cdm_spacetrack/cdm_spacetrack_emergency.csv"
        target_positive_ratio: desired high-risk fraction after augmentation
        noise_scale: passed through to augment_positive_events
        seed: RNG seed for the augmentation step

    Returns:
        (train_df, test_df) where train_df may contain "kelvins",
        "spacetrack", and "augmented" rows (see the "source" column) and
        test_df is the untouched Kelvins test split.
    """
    # Local import avoids a circular dependency with the loader module.
    from src.data.cdm_loader import load_dataset

    print("=" * 60)
    print(" Building Augmented Training Set")
    print("=" * 60)

    # Step 1: Load Kelvins
    print("\n1. Loading ESA Kelvins dataset ...")
    train_df, test_df = load_dataset(data_dir / "cdm")

    # Defragment and tag source
    # (copy() consolidates memory after column additions and protects the
    # caller's frames from mutation)
    train_df = train_df.copy()
    test_df = test_df.copy()
    train_df["source"] = "kelvins"
    test_df["source"] = "kelvins"

    # Count initial positives (label = last CDM's risk per event; > -5 is
    # the high-risk threshold used throughout this project)
    event_risks = train_df.groupby("event_id")["risk"].last()
    n_pos_initial = (event_risks > -5).sum()
    n_total_initial = len(event_risks)
    print(f"   Initial: {n_pos_initial} positive / {n_total_initial} total "
          f"({n_pos_initial/n_total_initial:.2%})")

    # Step 2: Space-Track integration
    st_path = data_dir / "cdm_spacetrack" / "cdm_spacetrack_emergency.csv"
    if st_path.exists():
        print(f"\n2. Integrating Space-Track emergency CDMs ...")
        train_df = integrate_spacetrack_positives(train_df, st_path)
    else:
        print(f"\n2. No Space-Track data found (skipping)")

    # Step 3: Time-series augmentation
    print(f"\n3. Augmenting positive events (target ratio: {target_positive_ratio:.0%}) ...")
    train_df = augment_positive_events(
        train_df,
        target_ratio=target_positive_ratio,
        noise_scale=noise_scale,
        seed=seed,
    )

    # Final stats, broken down by provenance of each event
    event_risks = train_df.groupby("event_id")["risk"].last()
    event_sources = train_df.groupby("event_id")["source"].first()
    n_kelvins = (event_sources == "kelvins").sum()
    n_spacetrack = (event_sources == "spacetrack").sum()
    n_augmented = (event_sources == "augmented").sum()
    n_pos_final = (event_risks > -5).sum()
    n_total_final = len(event_risks)

    print(f"\n{'=' * 60}")
    print(f" Final Training Set:")
    print(f" Kelvins events: {n_kelvins}")
    print(f" Space-Track events: {n_spacetrack}")
    print(f" Augmented events: {n_augmented}")
    print(f" Total events: {n_total_final}")
    print(f" Positive events: {n_pos_final} ({n_pos_final/n_total_final:.1%})")
    print(f" Total CDM rows: {len(train_df)}")
    print(f"{'=' * 60}")

    return train_df, test_df
|
src/data/cdm_loader.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Load and parse ESA Kelvins CDM dataset into structured formats."""
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class CDMSnapshot:
    """A single Conjunction Data Message (CDM) update.

    One snapshot corresponds to one row of the CDM table; a conjunction
    event is an ordered sequence of these (see ConjunctionEvent).
    """
    # Time remaining until the time of closest approach (TCA) —
    # units as in the source dataset (presumably days; TODO confirm).
    time_to_tca: float
    # Predicted closest-approach distance between the two objects.
    miss_distance: float
    # Relative speed between the objects at closest approach.
    relative_speed: float
    # Collision risk score; values above -5 are treated as high risk
    # elsewhere in this module, so this looks like a log10 probability —
    # NOTE(review): confirm scale against the dataset documentation.
    risk: float
    features: np.ndarray  # all numeric columns as a flat vector
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class ConjunctionEvent:
    """A complete conjunction event = an ordered sequence of CDM snapshots.

    Produced by build_events(); snapshots are ordered by time_to_tca
    descending, i.e. from the earliest CDM to the one closest to TCA.
    """
    event_id: int
    cdm_sequence: List[CDMSnapshot] = field(default_factory=list)
    # 1 if the FINAL CDM's risk exceeds -5, else 0.  (The previous comment
    # said "any CDM", but build_events labels on the last snapshot only.)
    risk_label: int = 0
    # miss_distance taken from the final CDM snapshot.
    final_miss_distance: float = 0.0
    # Apogee-derived altitude from t_h_apo / c_h_apo when available, else 0.
    altitude_km: float = 0.0
    # Chaser object type (c_object_type column), "unknown" when absent.
    object_type: str = ""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Columns we use for the feature vector (numeric only, excluding IDs/targets)
EXCLUDE_COLS = {"event_id", "time_to_tca", "risk", "mission_id"}


def load_cdm_csv(path: Path) -> pd.DataFrame:
    """Read a CDM CSV and zero-fill its sparse numeric feature columns.

    Identifier and target columns (EXCLUDE_COLS) are left untouched; only
    the numeric feature columns are filled, since some covariance columns
    are sparse in the source data.
    """
    frame = pd.read_csv(path)

    # Numeric feature columns = all numeric columns minus IDs/targets.
    numeric_names = frame.select_dtypes(include=[np.number]).columns.tolist()
    fill_cols = [name for name in numeric_names if name not in EXCLUDE_COLS]

    # Missing feature values are treated as 0 downstream.
    frame[fill_cols] = frame[fill_cols].fillna(0)

    return frame
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_dataset(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load the train and test CDM DataFrames found under *data_dir*.

    The CSVs may sit in a nested directory after archive extraction, so we
    search recursively and use the first match for each split.

    Raises:
        FileNotFoundError: when either split has no matching CSV.
    """
    candidates = {
        split: list(data_dir.rglob(f"*{split}*.csv"))
        for split in ("train", "test")
    }
    for split in ("train", "test"):
        if not candidates[split]:
            raise FileNotFoundError(f"No {split} CSV found in {data_dir}")

    train_path = candidates["train"][0]
    test_path = candidates["test"][0]

    print(f"Loading train: {train_path}")
    print(f"Loading test:  {test_path}")

    train_df = load_cdm_csv(train_path)
    test_df = load_cdm_csv(test_path)

    print(f"Train: {len(train_df)} rows, {train_df['event_id'].nunique()} events")
    print(f"Test:  {len(test_df)} rows, {test_df['event_id'].nunique()} events")

    return train_df, test_df
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_feature_columns(df: pd.DataFrame) -> list[str]:
    """Return the numeric feature column names, skipping IDs and targets."""
    return [
        name
        for name in df.select_dtypes(include=[np.number]).columns
        if name not in EXCLUDE_COLS
    ]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def build_events(df: pd.DataFrame, feature_cols: list[str] = None) -> list[ConjunctionEvent]:
    """Group CDM rows by event_id into ConjunctionEvent objects (vectorized).

    Args:
        df: CDM DataFrame (one row per CDM update).
        feature_cols: optional fixed list of feature columns (pass the train
            set's columns when building test events so both share one
            schema); columns missing from df are zero-filled.

    Returns:
        One ConjunctionEvent per event_id.  Each cdm_sequence is ordered by
        time_to_tca descending, i.e. from the earliest CDM to the one
        closest to TCA.

    Bug fix: the previous implementation read per-row scalars with
    ``df.iloc[ridx]`` AFTER ``sort_values``, where ``ridx`` was the row's
    ORIGINAL position — so time_to_tca / miss_distance / relative_speed /
    risk came from the wrong rows.  All scalars are now pre-extracted as
    arrays indexed by original position (which also avoids slow per-row
    ``.iloc`` lookups).
    """
    if feature_cols is None:
        feature_cols = get_feature_columns(df)
    else:
        # Ensure all requested columns exist; fill missing with 0
        for col in feature_cols:
            if col not in df.columns:
                df = df.copy()
                df[col] = 0.0

    # Pre-extract everything indexed by ORIGINAL row position (pre-sort).
    feature_matrix = df[feature_cols].values  # (N, F) float64
    feature_matrix = np.nan_to_num(feature_matrix, nan=0.0, posinf=0.0, neginf=0.0)

    def _scalar_column(name: str) -> np.ndarray:
        # Per-row scalar column as a float array; zeros when absent.
        if name in df.columns:
            return df[name].to_numpy(dtype=float)
        return np.zeros(len(df))

    tca_arr = _scalar_column("time_to_tca")
    miss_arr = _scalar_column("miss_distance")
    speed_arr = _scalar_column("relative_speed")
    risk_arr = _scalar_column("risk")

    # Sort by event_id, then time_to_tca descending (first CDM = furthest
    # from TCA); _row_idx remembers each row's original position.
    df = df.copy()
    df["_row_idx"] = np.arange(len(df))
    df = df.sort_values(["event_id", "time_to_tca"], ascending=[True, False])

    # Altitude column: prefer target apogee, fall back to chaser apogee.
    alt_col = next((c for c in ("t_h_apo", "c_h_apo") if c in df.columns), None)
    has_obj_type = "c_object_type" in df.columns

    events = []
    for event_id, group in df.groupby("event_id", sort=True):
        row_indices = group["_row_idx"].values  # original positions, sorted order

        cdm_seq = [
            CDMSnapshot(
                time_to_tca=float(tca_arr[ridx]),
                miss_distance=float(miss_arr[ridx]),
                relative_speed=float(speed_arr[ridx]),
                risk=float(risk_arr[ridx]),
                features=feature_matrix[ridx].astype(np.float32),
            )
            for ridx in row_indices
        ]

        final_cdm = cdm_seq[-1]  # closest to TCA
        risk_label = 1 if final_cdm.risk > -5 else 0  # high-risk threshold
        alt = float(group[alt_col].iloc[-1]) if alt_col else 0.0
        obj_type = str(group["c_object_type"].iloc[0]) if has_obj_type else "unknown"

        events.append(ConjunctionEvent(
            event_id=int(event_id),
            cdm_sequence=cdm_seq,
            risk_label=risk_label,
            final_miss_distance=final_cdm.miss_distance,
            altitude_km=alt,
            object_type=obj_type,
        ))

    n_high = sum(e.risk_label for e in events)
    # Guard the percentage against an empty input frame.
    pct = 100 * n_high / len(events) if events else 0.0
    print(f"Built {len(events)} events, {n_high} high-risk ({pct:.1f}%)")
    return events
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def events_to_flat_features(events: list[ConjunctionEvent]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Extract flat feature vectors from events for classical ML.
    Uses the LAST CDM snapshot (closest to TCA) + temporal trend features.

    Returns: (X, y_risk, y_miss)
    """
    feature_rows = []
    risk_labels = []
    log_miss_targets = []

    for event in events:
        seq = event.cdm_sequence
        final = seq[-1]
        n_cdms = len(seq)

        miss_arr = np.array([snap.miss_distance for snap in seq])
        risk_arr = np.array([snap.risk for snap in seq])
        tca_arr = np.array([snap.time_to_tca for snap in seq])

        # Slopes of miss distance / risk vs. time-to-TCA; degenerate
        # sequences (single CDM or constant TCA) get a zero slope.
        fit_possible = n_cdms > 1 and np.std(tca_arr) > 0
        miss_slope = float(np.polyfit(tca_arr, miss_arr, 1)[0]) if fit_possible else 0.0
        risk_slope = float(np.polyfit(tca_arr, risk_arr, 1)[0]) if fit_possible else 0.0

        temporal = np.array(
            [
                n_cdms,
                float(np.mean(miss_arr)) if n_cdms > 0 else 0.0,
                float(np.std(miss_arr)) if n_cdms > 1 else 0.0,
                miss_slope,
                risk_slope,
                # Net miss-distance change from first to last CDM.
                float(miss_arr[0] - miss_arr[-1]) if n_cdms > 1 else 0.0,
                final.time_to_tca,
                final.relative_speed,
            ],
            dtype=np.float32,
        )

        feature_rows.append(np.concatenate([final.features.copy(), temporal]))
        risk_labels.append(event.risk_label)
        # log1p compresses the heavy-tailed miss distance for regression.
        log_miss_targets.append(np.log1p(max(event.final_miss_distance, 0.0)))

    # Sanitize any NaN/inf that slipped through feature extraction.
    X = np.nan_to_num(np.stack(feature_rows), nan=0.0, posinf=0.0, neginf=0.0)
    return X, np.array(risk_labels), np.array(log_miss_targets)
|
src/data/counterfactual.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SGP4 counterfactual propagation — "what if no maneuver?" simulation.
|
| 2 |
+
|
| 3 |
+
For each likely-avoidance maneuver, propagates the pre-maneuver TLE forward
|
| 4 |
+
to estimate whether a close approach would have occurred. This generates
|
| 5 |
+
counterfactual "would-have-collided" labels for training enrichment.
|
| 6 |
+
|
| 7 |
+
Uses the sgp4 library for efficient satellite propagation.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
from datetime import datetime, timedelta, timezone
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from sgp4.api import Satrec, WGS72
|
| 16 |
+
from sgp4 import exporter
|
| 17 |
+
SGP4_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
SGP4_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
# Earth parameters
EARTH_RADIUS_KM = 6378.137  # WGS-84 equatorial radius (km)

# Counterfactual thresholds
COLLISION_THRESHOLD_KM = 1.0  # "Would have collided" if closer than this
NEARBY_ALT_BAND_KM = 50.0  # Altitude proximity for neighbor selection
NEARBY_RAAN_BAND_DEG = 30.0  # RAAN proximity for neighbor selection
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def celestrak_json_to_satrec(tle_json: dict) -> "Satrec":
    """Convert a CelesTrak GP JSON record to an sgp4 Satrec object.

    CelesTrak JSON includes TLE_LINE1/TLE_LINE2 when available. Falls
    back to constructing from orbital elements via sgp4init().

    Args:
        tle_json: CelesTrak GP JSON dict with orbital elements.

    Returns:
        sgp4 Satrec object ready for propagation.

    Raises:
        ImportError: If sgp4 is not installed.
        ValueError: If TLE data is insufficient.
    """
    if not SGP4_AVAILABLE:
        raise ImportError("sgp4 library is required: pip install sgp4")

    # Prefer TLE lines if available (most reliable)
    line1 = tle_json.get("TLE_LINE1", "")
    line2 = tle_json.get("TLE_LINE2", "")
    if line1 and line2:
        return Satrec.twoline2rv(line1, line2)

    # Construct from JSON orbital elements using sgp4init
    satrec = Satrec()

    # Parse epoch (ISO-8601; a trailing "Z" is normalized for fromisoformat)
    epoch_str = tle_json.get("EPOCH", "")
    if not epoch_str:
        raise ValueError("No EPOCH in TLE JSON")

    epoch_dt = datetime.fromisoformat(epoch_str.replace("Z", "+00:00"))
    if epoch_dt.tzinfo is None:
        epoch_dt = epoch_dt.replace(tzinfo=timezone.utc)

    # sgp4init expects elements in specific units
    no_kozai = float(tle_json.get("MEAN_MOTION", 0)) * (2.0 * math.pi / 1440.0)  # rev/day -> rad/min
    ecco = float(tle_json.get("ECCENTRICITY", 0))
    inclo = math.radians(float(tle_json.get("INCLINATION", 0)))
    nodeo = math.radians(float(tle_json.get("RA_OF_ASC_NODE", 0)))
    argpo = math.radians(float(tle_json.get("ARG_OF_PERICENTER", 0)))
    mo = math.radians(float(tle_json.get("MEAN_ANOMALY", 0)))
    bstar = float(tle_json.get("BSTAR", 0))
    norad_id = int(tle_json.get("NORAD_CAT_ID", 0))

    # Epoch as a Julian date; sgp4init takes it as days since
    # 1949 Dec 31 00:00 UT (JD 2433281.5).
    epoch_jd = _datetime_to_jd(epoch_dt)

    satrec.sgp4init(
        WGS72,  # gravity model
        'i',  # 'a' = old AFSPC mode, 'i' = improved
        norad_id,  # NORAD catalog number
        (epoch_jd - 2433281.5),  # epoch in days since 1949 Dec 31 00:00 UT
        bstar,  # BSTAR drag term
        0.0,  # ndot (not used in sgp4init 'i' mode)
        0.0,  # nddot (not used)
        ecco,  # eccentricity
        argpo,  # argument of perigee (radians)
        inclo,  # inclination (radians)
        mo,  # mean anomaly (radians)
        no_kozai,  # mean motion (radians/minute)
        nodeo,  # RAAN (radians)
    )

    return satrec
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _datetime_to_jd(dt: datetime) -> float:
|
| 113 |
+
"""Convert datetime to Julian Date."""
|
| 114 |
+
if dt.tzinfo is not None:
|
| 115 |
+
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
| 116 |
+
a = (14 - dt.month) // 12
|
| 117 |
+
y = dt.year + 4800 - a
|
| 118 |
+
m = dt.month + 12 * a - 3
|
| 119 |
+
jdn = dt.day + (153 * m + 2) // 5 + 365 * y + y // 4 - y // 100 + y // 400 - 32045
|
| 120 |
+
jd = jdn + (dt.hour - 12) / 24.0 + dt.minute / 1440.0 + dt.second / 86400.0
|
| 121 |
+
return jd
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _propagate_positions(satrec: "Satrec", start_jd: float, hours: float, step_min: float) -> np.ndarray:
    """Propagate a satellite and return position array (N x 3) in km.

    Positions are the ECI vectors returned by Satrec.sgp4(). Steps where
    sgp4 reports a non-zero error code are silently skipped, so N may be
    smaller than the requested number of steps.

    NOTE(review): because failed steps are dropped, row index i only maps
    to time start_jd + i*step_min when no step fails — callers that derive
    a time-of-closest-approach from the row index rely on that; confirm.

    Returns empty array if propagation fails.
    """
    n_steps = int(hours * 60 / step_min) + 1
    positions = []

    for i in range(n_steps):
        # Minutes from the satellite's own TLE epoch to (start_jd + i*step_min).
        minutes_since_epoch = (start_jd - satrec.jdsatepoch - satrec.jdsatepochF) * 1440.0 + i * step_min
        e, r, v = satrec.sgp4(satrec.jdsatepoch, satrec.jdsatepochF + minutes_since_epoch / 1440.0)
        if e != 0:
            # Non-zero sgp4 error code (e.g. decayed orbit) — skip this step.
            continue
        positions.append(r)

    if not positions:
        return np.array([]).reshape(0, 3)
    return np.array(positions)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def find_nearby_satellites(
    maneuvered_tle: dict,
    all_tles: list[dict],
    alt_band_km: float = NEARBY_ALT_BAND_KM,
    raan_band_deg: float = NEARBY_RAAN_BAND_DEG,
) -> list[dict]:
    """Find satellites in similar orbital shell to the maneuvered object.

    A candidate qualifies when its altitude and its RAAN (accounting for
    the 360° wrap) both fall within the given bands of the maneuvered
    satellite. The maneuvered object itself and invalid IDs are excluded.
    """
    from src.data.maneuver_detector import mean_motion_to_sma, sma_to_altitude

    own_id = int(maneuvered_tle.get("NORAD_CAT_ID", 0))
    own_mm = float(maneuvered_tle.get("MEAN_MOTION", 0))
    own_alt = sma_to_altitude(mean_motion_to_sma(own_mm))
    own_raan = float(maneuvered_tle.get("RA_OF_ASC_NODE", 0))

    def _in_shell(candidate: dict) -> bool:
        cand_id = int(candidate.get("NORAD_CAT_ID", 0))
        if cand_id == own_id or cand_id <= 0:
            return False

        cand_mm = float(candidate.get("MEAN_MOTION", 0))
        cand_alt = sma_to_altitude(mean_motion_to_sma(cand_mm))
        cand_raan = float(candidate.get("RA_OF_ASC_NODE", 0))

        # RAAN difference is circular: take the shorter way around.
        delta_raan = abs(cand_raan - own_raan)
        delta_raan = min(delta_raan, 360.0 - delta_raan)

        return abs(cand_alt - own_alt) < alt_band_km and delta_raan < raan_band_deg

    return [tle for tle in all_tles if _in_shell(tle)]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def propagate_counterfactual(
    pre_maneuver_tle: dict,
    nearby_tles: list[dict],
    hours_forward: float = 24.0,
    step_minutes: float = 10.0,
) -> dict:
    """Simulate "what if no maneuver?" using SGP4 propagation.

    Propagates the pre-maneuver TLE (before orbit change) forward and
    checks for close approaches with nearby satellites.

    Args:
        pre_maneuver_tle: Yesterday's TLE for the maneuvered satellite.
        nearby_tles: Current TLEs for nearby satellites.
        hours_forward: How far to propagate (hours).
        step_minutes: Time step for propagation (minutes).

    Returns:
        Dict with: min_distance_km, time_of_closest_approach,
        would_have_collided, closest_norad_id, n_neighbors_checked.
        On failure, min_distance_km is None and an "error" key is set.
    """
    if not SGP4_AVAILABLE:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": "sgp4 not installed",
        }

    try:
        target_sat = celestrak_json_to_satrec(pre_maneuver_tle)
    # Narrowed from `(ValueError, Exception)` — ValueError is a subclass of
    # Exception, so the tuple was redundant. Broad catch kept deliberately:
    # this is best-effort enrichment and must not crash the pipeline.
    except Exception as e:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": f"target TLE parse failed: {e}",
        }

    # Use current time as propagation start
    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    # Propagate maneuvered satellite (pre-maneuver orbit)
    target_positions = _propagate_positions(target_sat, start_jd, hours_forward, step_minutes)
    if len(target_positions) == 0:
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "error": "target propagation failed",
        }

    global_min_dist = float("inf")
    closest_norad = 0
    closest_time_offset_min = 0.0
    n_checked = 0

    for neighbor_tle in nearby_tles:
        try:
            neighbor_sat = celestrak_json_to_satrec(neighbor_tle)
        except Exception:  # best-effort: skip unparseable neighbors
            continue

        neighbor_positions = _propagate_positions(neighbor_sat, start_jd, hours_forward, step_minutes)
        if len(neighbor_positions) == 0:
            continue

        n_checked += 1

        # Compute distances at each timestep (use min of overlapping steps)
        n_common = min(len(target_positions), len(neighbor_positions))
        diffs = target_positions[:n_common] - neighbor_positions[:n_common]
        distances = np.linalg.norm(diffs, axis=1)
        min_idx = np.argmin(distances)
        min_dist = distances[min_idx]

        if min_dist < global_min_dist:
            global_min_dist = min_dist
            closest_norad = int(neighbor_tle.get("NORAD_CAT_ID", 0))
            # NOTE(review): assumes _propagate_positions dropped no steps,
            # otherwise index -> time is shifted.
            closest_time_offset_min = min_idx * step_minutes

    if global_min_dist == float("inf"):
        return {
            "min_distance_km": None,
            "would_have_collided": False,
            "n_neighbors_checked": n_checked,
            "error": "no valid neighbors propagated",
        }

    tca_dt = now + timedelta(minutes=closest_time_offset_min)

    return {
        "min_distance_km": round(global_min_dist, 3),
        "time_of_closest_approach": tca_dt.isoformat(),
        "would_have_collided": global_min_dist < COLLISION_THRESHOLD_KM,
        "closest_norad_id": closest_norad,
        "n_neighbors_checked": n_checked,
    }
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def compute_forward_trajectory(
    tle_1: dict,
    tle_2: dict,
    hours_forward: float = 120.0,
    step_minutes: float = 20.0,
) -> list[dict] | None:
    """Compute full trajectory time series for two satellites.

    Returns list of trajectory points with ECI positions and separation
    distance, suitable for baking into the webapp alerts JSON so the
    frontend doesn't need to do SGP4 propagation or load TLE data.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        hours_forward: How far to propagate (default 120h = 5 days).
        step_minutes: Time step for propagation (minutes).

    Returns:
        List of dicts with: h (hours from start), d (distance km),
        s1 [x,y,z] ECI km, s2 [x,y,z] ECI km. None if propagation fails.
    """
    if not SGP4_AVAILABLE:
        return None

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Narrowed from `(ValueError, Exception)` — ValueError is already an
    # Exception subclass, so the tuple was redundant.
    except Exception:
        return None

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    n_steps = int(hours_forward * 60 / step_minutes) + 1
    points = []

    for i in range(n_steps):
        mins = i * step_minutes
        target_jd = start_jd + mins / 1440.0
        # Split JD into whole + fraction for sgp4's two-float time argument.
        jd_whole = int(target_jd)
        jd_frac = target_jd - jd_whole

        e1, r1, _ = sat1.sgp4(jd_whole, jd_frac)
        e2, r2, _ = sat2.sgp4(jd_whole, jd_frac)

        # Skip steps where either propagation errored or produced non-finite values.
        if e1 != 0 or e2 != 0:
            continue
        if not all(math.isfinite(v) for v in r1 + r2):
            continue

        dx = r1[0] - r2[0]
        dy = r1[1] - r2[1]
        dz = r1[2] - r2[2]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)

        points.append({
            "h": round(mins / 60.0, 2),
            "d": round(dist, 1),
            "s1": [round(r1[0], 1), round(r1[1], 1), round(r1[2], 1)],
            "s2": [round(r2[0], 1), round(r2[1], 1), round(r2[2], 1)],
        })

    return points if points else None
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def compute_tca_trail(
    tle_1: dict,
    tle_2: dict,
    tca_hours: float,
    half_window_min: float = 30.0,
    step_minutes: float = 0.25,
) -> list[dict] | None:
    """Compute dense trail around TCA for globe orbital path visualization.

    Returns 15-sec resolution positions for ±30 min around TCA.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        tca_hours: Hours from now to TCA (from compute_forward_tca).
        half_window_min: Half window in minutes around TCA.
        step_minutes: Time step in minutes.

    Returns:
        List of dicts with s1 [x,y,z] and s2 [x,y,z] ECI km. None if fails.
    """
    if not SGP4_AVAILABLE:
        return None

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Narrowed from `(ValueError, Exception)` — ValueError is already an
    # Exception subclass, so the tuple was redundant.
    except Exception:
        return None

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    tca_min = tca_hours * 60.0
    t_start = tca_min - half_window_min
    t_end = tca_min + half_window_min
    n_steps = int((t_end - t_start) / step_minutes) + 1

    trail = []
    for i in range(n_steps):
        mins = t_start + i * step_minutes
        target_jd = start_jd + mins / 1440.0
        # Split JD into whole + fraction for sgp4's two-float time argument.
        jd_whole = int(target_jd)
        jd_frac = target_jd - jd_whole

        e1, r1, _ = sat1.sgp4(jd_whole, jd_frac)
        e2, r2, _ = sat2.sgp4(jd_whole, jd_frac)

        # Skip steps where either propagation errored or produced non-finite values.
        if e1 != 0 or e2 != 0:
            continue
        if not all(math.isfinite(v) for v in r1 + r2):
            continue

        dx = r1[0] - r2[0]
        dy = r1[1] - r2[1]
        dz = r1[2] - r2[2]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)

        trail.append({
            "h": round(mins / 60.0, 3),
            "d": round(dist, 1),
            "s1": [round(r1[0], 1), round(r1[1], 1), round(r1[2], 1)],
            "s2": [round(r2[0], 1), round(r2[1], 1), round(r2[2], 1)],
        })

    return trail if trail else None
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def compute_forward_tca(
    tle_1: dict,
    tle_2: dict,
    hours_forward: float = 120.0,
    step_minutes: float = 10.0,
) -> dict:
    """Compute forward Time of Closest Approach between two satellites.

    Propagates both satellites forward using SGP4 and finds the minimum
    separation distance and when it occurs.

    Args:
        tle_1: CelesTrak GP JSON for satellite 1.
        tle_2: CelesTrak GP JSON for satellite 2.
        hours_forward: How far to propagate (default 120h = 5 days).
        step_minutes: Time step for propagation (minutes).

    Returns:
        Dict with: tca_hours, tca_min_distance_km (both None on failure).
    """
    if not SGP4_AVAILABLE:
        return {"tca_hours": None, "tca_min_distance_km": None}

    try:
        sat1 = celestrak_json_to_satrec(tle_1)
        sat2 = celestrak_json_to_satrec(tle_2)
    # Narrowed from `(ValueError, Exception) as e` — ValueError was redundant
    # (it subclasses Exception) and `e` was never used.
    except Exception:
        return {"tca_hours": None, "tca_min_distance_km": None}

    now = datetime.now(timezone.utc)
    start_jd = _datetime_to_jd(now)

    pos1 = _propagate_positions(sat1, start_jd, hours_forward, step_minutes)
    pos2 = _propagate_positions(sat2, start_jd, hours_forward, step_minutes)

    if len(pos1) == 0 or len(pos2) == 0:
        return {"tca_hours": None, "tca_min_distance_km": None}

    # Pairwise separation over the overlapping steps of both trajectories.
    n_common = min(len(pos1), len(pos2))
    diffs = pos1[:n_common] - pos2[:n_common]
    distances = np.linalg.norm(diffs, axis=1)
    min_idx = int(np.argmin(distances))
    min_dist = float(distances[min_idx])
    # NOTE(review): index -> time assumes _propagate_positions dropped no steps.
    tca_hours = min_idx * step_minutes / 60.0

    return {
        "tca_hours": round(tca_hours, 1),
        "tca_min_distance_km": round(min_dist, 1),
    }
|
src/data/density_features.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code — 2026-02-13
|
| 2 |
+
"""Orbital density features derived from the CRASH Clock framework.
|
| 3 |
+
|
| 4 |
+
Computes population-level orbital density metrics for each conjunction event,
|
| 5 |
+
based on the altitude distribution of all events in the training set.
|
| 6 |
+
|
| 7 |
+
The key insight from Thiele et al. (2025) "An Orbital House of Cards":
|
| 8 |
+
collision rate scales as n² * A_col * v_r — so a conjunction at a crowded
|
| 9 |
+
altitude (550 km Starlink shell) is fundamentally riskier than the same
|
| 10 |
+
miss_distance at a sparse altitude (1200 km).
|
| 11 |
+
|
| 12 |
+
These features are computed from the TRAINING set only and applied to
|
| 13 |
+
validation/test sets to prevent data leakage.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
# Physical constants
EARTH_RADIUS_KM = 6371.0  # mean Earth radius (km)
GM_M3_S2 = 3.986004418e14  # Earth gravitational parameter (m³/s²)

# CRASH Clock cross-sections from Thiele et al. Table (10m-5m-10cm)
A_COL_SAT_SAT = 300.0  # m² (satellite-satellite, 10m approach)
A_COL_SAT_DEBRIS = 79.0  # m² (satellite-debris, 5m approach)

# Altitude binning
BIN_WIDTH_KM = 25  # km per altitude bin
ALT_MIN_KM = 150  # lower edge of binned altitude range (km)
ALT_MAX_KM = 2100  # upper edge of binned altitude range (km)

# Feature names that will be added to DataFrames
DENSITY_FEATURES = [
    "shell_density",  # events per km³ in altitude bin
    "shell_collision_rate",  # Γ from CRASH Clock Eq. 2 (per second)
    "local_crash_clock_log",  # log10(seconds to expected collision in shell)
    "altitude_percentile",  # CDF position in event altitude distribution
    "n_events_in_shell",  # raw count of training events at this altitude
    "shell_risk_rate",  # fraction of high-risk events in this altitude bin
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _orbital_speed_kms(altitude_km: float) -> float:
    """Speed of a circular orbit at the given altitude, in km/s."""
    # v = sqrt(GM / r); work in SI, then convert m/s to km/s.
    radius_m = (EARTH_RADIUS_KM + altitude_km) * 1000.0
    speed_ms = np.sqrt(GM_M3_S2 / radius_m)
    return speed_ms / 1000.0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _mean_relative_speed_kms(altitude_km: float) -> float:
    """Mean encounter speed between shell members: v_r = (4/3) * v_orbital (Eq. 7)."""
    return (4.0 / 3.0) * _orbital_speed_kms(altitude_km)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _shell_volume_km3(altitude_km: float, width_km: float) -> float:
    """Volume (km³) of a spherical shell centered at the given altitude."""
    half_width = width_km / 2.0
    r_lo = EARTH_RADIUS_KM + altitude_km - half_width
    r_hi = EARTH_RADIUS_KM + altitude_km + half_width
    # Difference of two sphere volumes.
    return (4.0 / 3.0) * np.pi * (r_hi**3 - r_lo**3)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class OrbitalDensityComputer:
|
| 64 |
+
"""Computes orbital density features from a training DataFrame.
|
| 65 |
+
|
| 66 |
+
Fit on training data, then transform any DataFrame (train/val/test)
|
| 67 |
+
to add density-based static features per event.
|
| 68 |
+
|
| 69 |
+
The density is computed from event altitudes, NOT from a full TLE
|
| 70 |
+
catalog, so it represents the conjunction density distribution rather
|
| 71 |
+
than the full RSO population. For the Kelvins dataset, this captures
|
| 72 |
+
where conjunction events cluster (which correlates with RSO density).
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self, bin_width_km: float = BIN_WIDTH_KM):
|
| 76 |
+
self.bin_width_km = bin_width_km
|
| 77 |
+
self.bin_edges = np.arange(ALT_MIN_KM, ALT_MAX_KM + bin_width_km, bin_width_km)
|
| 78 |
+
self.bin_centers = (self.bin_edges[:-1] + self.bin_edges[1:]) / 2.0
|
| 79 |
+
self.n_bins = len(self.bin_centers)
|
| 80 |
+
|
| 81 |
+
# Fitted state (populated by fit())
|
| 82 |
+
self.event_counts = None # events per bin
|
| 83 |
+
self.density_per_bin = None # events / km³ per bin
|
| 84 |
+
self.collision_rate = None # Γ per bin (events/s)
|
| 85 |
+
self.crash_clock_log = None # log10(seconds to collision) per bin
|
| 86 |
+
self.risk_rate_per_bin = None # fraction high-risk per bin
|
| 87 |
+
self.altitude_cdf = None # cumulative distribution
|
| 88 |
+
self.is_fitted = False
|
| 89 |
+
|
| 90 |
+
def _event_altitude(self, df: pd.DataFrame) -> np.ndarray:
|
| 91 |
+
"""Compute conjunction altitude for each event (last CDM row).
|
| 92 |
+
|
| 93 |
+
Uses mean of target and chaser perigee altitudes as the approximate
|
| 94 |
+
conjunction altitude. Falls back to semi-major axis minus Earth radius.
|
| 95 |
+
"""
|
| 96 |
+
event_df = df.groupby("event_id").last()
|
| 97 |
+
|
| 98 |
+
# Primary: mean of perigee altitudes (where most conjunctions happen)
|
| 99 |
+
t_alt = np.zeros(len(event_df))
|
| 100 |
+
c_alt = np.zeros(len(event_df))
|
| 101 |
+
|
| 102 |
+
if "t_h_per" in event_df.columns:
|
| 103 |
+
t_alt = event_df["t_h_per"].fillna(0).values
|
| 104 |
+
elif "t_j2k_sma" in event_df.columns:
|
| 105 |
+
t_alt = event_df["t_j2k_sma"].fillna(EARTH_RADIUS_KM).values - EARTH_RADIUS_KM
|
| 106 |
+
|
| 107 |
+
if "c_h_per" in event_df.columns:
|
| 108 |
+
c_alt = event_df["c_h_per"].fillna(0).values
|
| 109 |
+
elif "c_j2k_sma" in event_df.columns:
|
| 110 |
+
c_alt = event_df["c_j2k_sma"].fillna(EARTH_RADIUS_KM).values - EARTH_RADIUS_KM
|
| 111 |
+
|
| 112 |
+
altitudes = (t_alt + c_alt) / 2.0
|
| 113 |
+
# Clamp to valid range
|
| 114 |
+
altitudes = np.clip(altitudes, ALT_MIN_KM, ALT_MAX_KM - 1)
|
| 115 |
+
return altitudes, event_df.index.values
|
| 116 |
+
|
| 117 |
+
def fit(self, train_df: pd.DataFrame) -> "OrbitalDensityComputer":
|
| 118 |
+
"""Fit density distribution from training data.
|
| 119 |
+
|
| 120 |
+
Must be called before transform(). Only uses training data
|
| 121 |
+
to prevent information leakage into validation/test sets.
|
| 122 |
+
"""
|
| 123 |
+
altitudes, event_ids = self._event_altitude(train_df)
|
| 124 |
+
|
| 125 |
+
# Histogram: count events per altitude bin
|
| 126 |
+
self.event_counts, _ = np.histogram(altitudes, bins=self.bin_edges)
|
| 127 |
+
|
| 128 |
+
# Density: events per km³ in each shell
|
| 129 |
+
volumes = np.array([
|
| 130 |
+
_shell_volume_km3(c, self.bin_width_km)
|
| 131 |
+
for c in self.bin_centers
|
| 132 |
+
])
|
| 133 |
+
self.density_per_bin = self.event_counts / np.maximum(volumes, 1e-6)
|
| 134 |
+
|
| 135 |
+
# Collision rate per shell: Γ = (1/2) * n² * A_col * v_r * V
|
| 136 |
+
# Using satellite-satellite cross-section as the primary concern
|
| 137 |
+
self.collision_rate = np.zeros(self.n_bins)
|
| 138 |
+
for i, (center, density, volume) in enumerate(
|
| 139 |
+
zip(self.bin_centers, self.density_per_bin, volumes)
|
| 140 |
+
):
|
| 141 |
+
v_r = _mean_relative_speed_kms(center) # km/s
|
| 142 |
+
# Convert A_col from m² to km², v_r already in km/s
|
| 143 |
+
a_col_km2 = A_COL_SAT_SAT / 1e6 # m² → km²
|
| 144 |
+
# Γ = 0.5 * n² * A * v_r * V (units: per second)
|
| 145 |
+
gamma = 0.5 * density**2 * a_col_km2 * v_r * volume
|
| 146 |
+
self.collision_rate[i] = gamma
|
| 147 |
+
|
| 148 |
+
# CRASH Clock per shell: τ = 1/Γ (in seconds), log10 for feature
|
| 149 |
+
with np.errstate(divide="ignore"):
|
| 150 |
+
tau = 1.0 / np.maximum(self.collision_rate, 1e-30)
|
| 151 |
+
self.crash_clock_log = np.log10(np.clip(tau, 1.0, 1e15))
|
| 152 |
+
|
| 153 |
+
# Risk rate per bin: fraction of positive events
|
| 154 |
+
risk_per_event = train_df.groupby("event_id")["risk"].last()
|
| 155 |
+
is_high_risk = (risk_per_event > -5).astype(float).values
|
| 156 |
+
|
| 157 |
+
self.risk_rate_per_bin = np.zeros(self.n_bins)
|
| 158 |
+
for i in range(self.n_bins):
|
| 159 |
+
mask = (altitudes >= self.bin_edges[i]) & (altitudes < self.bin_edges[i + 1])
|
| 160 |
+
if mask.sum() > 0:
|
| 161 |
+
self.risk_rate_per_bin[i] = is_high_risk[mask].mean()
|
| 162 |
+
|
| 163 |
+
# Cumulative altitude distribution for percentile feature
|
| 164 |
+
sorted_alts = np.sort(altitudes)
|
| 165 |
+
self.altitude_cdf = sorted_alts
|
| 166 |
+
|
| 167 |
+
self.is_fitted = True
|
| 168 |
+
print(f" OrbitalDensityComputer fitted on {len(event_ids)} events")
|
| 169 |
+
print(f" Altitude range: {altitudes.min():.0f} - {altitudes.max():.0f} km")
|
| 170 |
+
print(f" Peak density bin: {self.bin_centers[np.argmax(self.density_per_bin)]:.0f} km "
|
| 171 |
+
f"({self.event_counts.max()} events)")
|
| 172 |
+
peak_idx = np.argmax(self.collision_rate)
|
| 173 |
+
if self.collision_rate[peak_idx] > 0:
|
| 174 |
+
print(f" Highest collision rate: {self.bin_centers[peak_idx]:.0f} km "
|
| 175 |
+
f"(tau = {10**self.crash_clock_log[peak_idx]:.0f} s)")
|
| 176 |
+
|
| 177 |
+
return self
|
| 178 |
+
|
| 179 |
+
def _get_bin_index(self, altitudes: np.ndarray) -> np.ndarray:
    """Map each altitude (km) to its shell-bin index, clamped to [0, n_bins-1]."""
    # digitize returns 1-based positions relative to bin_edges; shift to 0-based
    # and clamp so out-of-range altitudes land in the first/last shell.
    return np.clip(np.digitize(altitudes, self.bin_edges) - 1, 0, self.n_bins - 1)
|
| 183 |
+
|
| 184 |
+
def _altitude_percentile(self, altitudes: np.ndarray) -> np.ndarray:
    """Fraction of training altitudes below each given altitude (0..1)."""
    ranks = np.searchsorted(self.altitude_cdf, altitudes)
    return ranks / len(self.altitude_cdf)
|
| 187 |
+
|
| 188 |
+
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
    """Add density features to a CDM DataFrame.

    Features are computed per event_id and broadcast to all CDM rows
    (they're static features — same for every CDM in the sequence).

    Args:
        df: CDM DataFrame containing an "event_id" column.

    Returns:
        A copy of `df` with one column added per name in DENSITY_FEATURES.

    Raises:
        RuntimeError: If called before fit().
    """
    if not self.is_fitted:
        raise RuntimeError("Must call fit() before transform()")

    df = df.copy()
    altitudes, event_ids = self._event_altitude(df)
    bin_indices = self._get_bin_index(altitudes)
    # Fix: compute all percentiles in one vectorized searchsorted call instead
    # of building a one-element array per event inside the loop.
    percentiles = self._altitude_percentile(altitudes)

    # Build event-level features (one dict per conjunction event).
    event_features = {}
    for i, eid in enumerate(event_ids):
        bi = bin_indices[i]
        event_features[eid] = {
            "shell_density": self.density_per_bin[bi],
            "shell_collision_rate": self.collision_rate[bi],
            "local_crash_clock_log": self.crash_clock_log[bi],
            "altitude_percentile": percentiles[i],
            "n_events_in_shell": float(self.event_counts[bi]),
            "shell_risk_rate": self.risk_rate_per_bin[bi],
        }

    # Broadcast event-level features to every CDM row via event_id; events
    # unseen in event_ids (shouldn't happen) default to 0.0.
    for col in DENSITY_FEATURES:
        col_map = {eid: feats[col] for eid, feats in event_features.items()}
        df[col] = df["event_id"].map(col_map).fillna(0.0)

    return df
|
| 223 |
+
|
| 224 |
+
def save(self, path: Path):
    """Serialize the fitted binning state to a JSON file for inference.

    Raises:
        RuntimeError: If the computer has not been fitted yet.
    """
    if not self.is_fitted:
        raise RuntimeError("Must call fit() before save()")

    # Scalar config first, then each fitted array converted to a plain list;
    # key order matches the load() counterpart.
    array_fields = (
        "bin_edges", "bin_centers", "event_counts", "density_per_bin",
        "collision_rate", "crash_clock_log", "risk_rate_per_bin", "altitude_cdf",
    )
    state = {"bin_width_km": self.bin_width_km}
    for field in array_fields:
        state[field] = getattr(self, field).tolist()

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(state, f, indent=2)
|
| 242 |
+
|
| 243 |
+
@classmethod
def load(cls, path: Path) -> "OrbitalDensityComputer":
    """Restore a fitted computer from a JSON state file written by save()."""
    with open(path) as f:
        state = json.load(f)

    obj = cls(bin_width_km=state["bin_width_km"])
    # Rehydrate every fitted array back into numpy form.
    for field in (
        "bin_edges", "bin_centers", "event_counts", "density_per_bin",
        "collision_rate", "crash_clock_log", "risk_rate_per_bin", "altitude_cdf",
    ):
        setattr(obj, field, np.array(state[field]))
    obj.n_bins = len(obj.bin_centers)
    obj.is_fitted = True
    return obj
|
src/data/firebase_client.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
| 2 |
+
"""Firebase Firestore client for prediction logging.
|
| 3 |
+
|
| 4 |
+
Stores daily conjunction predictions and maneuver detection outcomes.
|
| 5 |
+
Uses the Firestore REST API to avoid heavy SDK dependencies.
|
| 6 |
+
Falls back to local JSONL logging if Firebase is not configured.
|
| 7 |
+
|
| 8 |
+
Environment variables:
|
| 9 |
+
FIREBASE_SERVICE_ACCOUNT: JSON string of the service account key
|
| 10 |
+
FIREBASE_PROJECT_ID: Project ID (auto-detected from service account if not set)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
import numpy as np
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _json_default(obj):
|
| 22 |
+
"""Handle numpy types that json.dumps can't serialize."""
|
| 23 |
+
if isinstance(obj, (np.integer,)):
|
| 24 |
+
return int(obj)
|
| 25 |
+
if isinstance(obj, (np.floating,)):
|
| 26 |
+
return float(obj)
|
| 27 |
+
if isinstance(obj, (np.bool_,)):
|
| 28 |
+
return bool(obj)
|
| 29 |
+
if isinstance(obj, np.ndarray):
|
| 30 |
+
return obj.tolist()
|
| 31 |
+
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
| 32 |
+
|
| 33 |
+
# Optional dependency: google-cloud-firestore (lightweight client SDK).
# When the import fails, HAS_FIRESTORE stays False and PredictionLogger
# falls back to local JSONL logging only.
try:
    from google.cloud.firestore import Client as FirestoreClient
    from google.oauth2.service_account import Credentials
    HAS_FIRESTORE = True
except ImportError:
    HAS_FIRESTORE = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PredictionLogger:
    """Log predictions to Firebase Firestore with local JSONL fallback.

    Local JSONL files are always written. Firestore upload happens only when
    FIREBASE_SERVICE_ACCOUNT is set and the SDK import succeeded.
    """

    # Firestore rejects a WriteBatch with more than 500 operations; stay below.
    _MAX_BATCH_WRITES = 450

    def __init__(self, local_dir: Path = None):
        """Create the logger; `local_dir` defaults to data/prediction_logs."""
        self.db = None  # Firestore client, or None when unavailable
        self.local_dir = local_dir or Path("data/prediction_logs")
        self.local_dir.mkdir(parents=True, exist_ok=True)
        self._init_firebase()

    def _init_firebase(self):
        """Initialize the Firestore client from environment variables."""
        sa_json = os.environ.get("FIREBASE_SERVICE_ACCOUNT", "")
        if not sa_json or not HAS_FIRESTORE:
            if not HAS_FIRESTORE:
                print(" Firebase SDK not installed (pip install google-cloud-firestore)")
                print(" Using local JSONL logging only")
            return

        try:
            sa_info = json.loads(sa_json)
            creds = Credentials.from_service_account_info(sa_info)
            project_id = sa_info.get("project_id", os.environ.get("FIREBASE_PROJECT_ID", ""))
            self.db = FirestoreClient(project=project_id, credentials=creds)
            print(f" Firebase Firestore connected (project: {project_id})")
        except Exception as e:
            print(f" Firebase init failed: {e}")
            print(" Falling back to local JSONL logging")

    def _commit_in_batches(self, writes):
        """Commit (doc_ref, payload) pairs in chunks under the Firestore batch limit.

        Fix: the original committed one batch regardless of size, which fails
        for more than 500 writes.
        """
        batch = self.db.batch()
        pending = 0
        for doc_ref, payload in writes:
            batch.set(doc_ref, payload)
            pending += 1
            if pending >= self._MAX_BATCH_WRITES:
                batch.commit()
                batch = self.db.batch()
                pending = 0
        if pending:
            batch.commit()

    def log_predictions(self, date_str: str, predictions: list[dict]):
        """Log a batch of daily predictions.

        Args:
            date_str: Date string (YYYY-MM-DD)
            predictions: List of prediction dicts with keys:
                sat1_norad, sat2_norad, sat1_name, sat2_name,
                risk_score, altitude_km, model_used
        """
        # Fix: enrich copies so the caller's dicts are not mutated.
        stamp = datetime.now(timezone.utc).isoformat()
        enriched = [{**pred, "date": date_str, "logged_at": stamp} for pred in predictions]

        # Always save locally
        local_file = self.local_dir / f"{date_str}.jsonl"
        with open(local_file, "a") as f:
            for pred in enriched:
                f.write(json.dumps(pred, default=_json_default) + "\n")
        print(f" Saved {len(predictions)} predictions to {local_file}")

        # Firebase upload
        if self.db:
            try:
                # Parent document records batch metadata (overwritten on re-log).
                day_doc = self.db.collection("predictions").document(date_str)
                day_doc.set({"date": date_str, "count": len(predictions)})

                writes = []
                for i, pred in enumerate(enriched):
                    doc_ref = day_doc.collection("pairs").document(f"pair_{i:04d}")
                    writes.append((doc_ref, pred))
                self._commit_in_batches(writes)
                print(f" Uploaded {len(predictions)} predictions to Firebase")
            except Exception as e:
                print(f" Firebase upload failed: {e}")

    def log_outcomes(self, date_str: str, outcomes: list[dict]):
        """Log maneuver detection outcomes for a previous prediction date.

        Args:
            date_str: Original prediction date (YYYY-MM-DD)
            outcomes: List of outcome dicts with keys:
                sat1_norad, sat2_norad, sat1_maneuvered, sat2_maneuvered,
                sat1_delta_a_m, sat2_delta_a_m, validated_at
        """
        # Fix: enrich copies so the caller's dicts are not mutated.
        stamp = datetime.now(timezone.utc).isoformat()
        enriched = [
            {**outcome, "prediction_date": date_str, "validated_at": stamp}
            for outcome in outcomes
        ]

        local_file = self.local_dir / f"{date_str}_outcomes.jsonl"
        with open(local_file, "a") as f:
            for outcome in enriched:
                f.write(json.dumps(outcome, default=_json_default) + "\n")
        print(f" Saved {len(outcomes)} outcomes to {local_file}")

        if self.db:
            try:
                day_doc = self.db.collection("outcomes").document(date_str)
                writes = []
                for i, outcome in enumerate(enriched):
                    doc_ref = day_doc.collection("results").document(f"result_{i:04d}")
                    writes.append((doc_ref, outcome))
                self._commit_in_batches(writes)
                print(f" Uploaded {len(outcomes)} outcomes to Firebase")
            except Exception as e:
                print(f" Firebase upload failed: {e}")

    def log_daily_summary(self, date_str: str, summary: dict):
        """Log a daily summary (n_predictions, n_maneuvers_detected, accuracy, etc)."""
        # Fix: write an enriched copy instead of mutating the caller's dict.
        record = {**summary, "date": date_str}

        local_file = self.local_dir / "daily_summaries.jsonl"
        with open(local_file, "a") as f:
            f.write(json.dumps(record, default=_json_default) + "\n")

        if self.db:
            try:
                self.db.collection("daily_summaries").document(date_str).set(record)
                print(" Uploaded daily summary to Firebase")
            except Exception as e:
                print(f" Firebase summary upload failed: {e}")

    def get_predictions_for_date(self, date_str: str) -> list[dict]:
        """Retrieve predictions for a date (from local files)."""
        local_file = self.local_dir / f"{date_str}.jsonl"
        if not local_file.exists():
            return []
        predictions = []
        with open(local_file) as f:
            for line in f:
                line = line.strip()
                if line:
                    predictions.append(json.loads(line))
        return predictions
|
src/data/maneuver_classifier.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Classify detected satellite maneuvers into avoidance vs routine.
|
| 2 |
+
|
| 3 |
+
Enriches each maneuver with:
|
| 4 |
+
- magnitude_class: micro/small/medium/large based on delta-v
|
| 5 |
+
- constellation: starlink/oneweb/iridium/other
|
| 6 |
+
- is_stationkeeping: regularity-based detection from maneuver history
|
| 7 |
+
- likely_avoidance: heuristic combining all signals
|
| 8 |
+
|
| 9 |
+
These enrichments improve training label quality for PI-TFT fine-tuning
|
| 10 |
+
without changing the model's feature space.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import numpy as np
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Delta-v magnitude bins (m/s): (label, inclusive lower bound, exclusive upper bound).
# Ranges are contiguous and cover [0, inf), so classify_magnitude always matches.
MAGNITUDE_BINS = [
    ("micro", 0.0, 0.5),
    ("small", 0.5, 2.0),
    ("medium", 2.0, 10.0),
    ("large", 10.0, float("inf")),
]

# Constellation name patterns, checked in order; first match wins.
CONSTELLATION_PATTERNS = [
    ("starlink", re.compile(r"STARLINK", re.IGNORECASE)),
    ("oneweb", re.compile(r"ONEWEB", re.IGNORECASE)),
    ("iridium", re.compile(r"IRIDIUM", re.IGNORECASE)),
]

# Stationkeeping regularity threshold: coefficient of variation (std/mean) of
# the time intervals between consecutive maneuvers.
STATIONKEEPING_CV_THRESHOLD = 0.3
MIN_HISTORY_FOR_SK = 3  # Need at least 3 past maneuvers to detect a pattern
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def classify_magnitude(delta_v_m_s: float) -> str:
    """Return the magnitude-class label for a delta-v in m/s."""
    magnitude = abs(delta_v_m_s)
    matches = [label for label, lo, hi in MAGNITUDE_BINS if lo <= magnitude < hi]
    # Fallback covers values no bin claims (e.g. NaN comparisons are all False).
    return matches[0] if matches else "large"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def detect_constellation(name: str) -> str:
    """Return the constellation tag matching the satellite name, else "other"."""
    return next(
        (tag for tag, pattern in CONSTELLATION_PATTERNS if pattern.search(name)),
        "other",
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def detect_stationkeeping(history: list[dict]) -> bool:
    """Detect stationkeeping from the regularity of past maneuver intervals.

    A maneuver pattern counts as stationkeeping when the coefficient of
    variation (std/mean) of the gaps between consecutive maneuvers is below
    STATIONKEEPING_CV_THRESHOLD.

    Args:
        history: Past maneuver records for one NORAD ID, each carrying a
            'detected_at' ISO-8601 timestamp.

    Returns:
        True when the maneuver cadence looks regular enough to be stationkeeping.
    """
    if not history or len(history) < MIN_HISTORY_FOR_SK:
        return False

    # Collect parseable epochs as POSIX seconds; skip blank/bad timestamps.
    epochs = []
    for record in history:
        raw = record.get("detected_at", "")
        if not raw:
            continue
        try:
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        except (ValueError, TypeError):
            continue
        epochs.append(parsed.timestamp())

    if len(epochs) < MIN_HISTORY_FOR_SK:
        return False

    gaps = np.diff(sorted(epochs))
    if len(gaps) < 2:
        return False

    mean_gap = np.mean(gaps)
    if mean_gap <= 0:
        return False

    return np.std(gaps) / mean_gap < STATIONKEEPING_CV_THRESHOLD
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def classify_maneuver(maneuver: dict, history: list[dict] = None) -> dict:
    """Enrich a detected maneuver with classification flags.

    Args:
        maneuver: Maneuver dict from detect_maneuvers() with keys:
            norad_id, name, delta_v_m_s, delta_a_m, etc.
        history: Past maneuver records for the same NORAD ID (optional).

    Returns:
        A copy of the maneuver with enrichment fields added.
    """
    delta_v = maneuver.get("delta_v_m_s", 0.0)
    sat_name = maneuver.get("name", "")

    magnitude_class = classify_magnitude(delta_v)
    constellation = detect_constellation(sat_name)
    is_sk = detect_stationkeeping(history) if history else False

    # Avoidance heuristic: small, irregular burns look like collision
    # avoidance; Starlink CAMs are typically < 1 m/s regardless of cadence.
    likely_avoidance = (
        (not is_sk and magnitude_class in ("micro", "small") and delta_v < 5.0)
        or (constellation == "starlink" and delta_v < 1.0)
    )

    enriched = dict(maneuver)
    enriched.update({
        "magnitude_class": magnitude_class,
        "constellation": constellation,
        "is_stationkeeping": is_sk,
        "likely_avoidance": likely_avoidance,
        "enrichment_version": 1,
        # Phase B/C defaults — overwritten later if data is available
        "has_cdm": False,
        "cdm_pc": None,
        "cdm_miss_distance_km": None,
        "counterfactual_min_distance_km": None,
        "would_have_collided": False,
        "counterfactual_closest_norad": None,
    })
    return enriched
|
src/data/maneuver_detector.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
| 2 |
+
"""Detect satellite maneuvers from TLE data changes.
|
| 3 |
+
|
| 4 |
+
Compares successive TLEs for the same satellite. An abrupt change in
|
| 5 |
+
semi-major axis (> threshold) indicates a maneuver — either collision
|
| 6 |
+
avoidance, orbit maintenance, or orbit raising.
|
| 7 |
+
|
| 8 |
+
Based on Kelecy (2007) and Patera & Peterson (2021).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import numpy as np
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from datetime import datetime, timedelta, timezone
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Earth parameters (WGS84)
MU_EARTH = 398600.4418  # standard gravitational parameter, km^3/s^2
EARTH_RADIUS_KM = 6378.137  # equatorial radius, km

# Maneuver detection thresholds: minimum semi-major-axis change (meters)
# between successive TLEs to count as a maneuver rather than fit noise.
DEFAULT_DELTA_A_THRESHOLD_M = 200  # meters — below this is noise
STARLINK_DELTA_A_THRESHOLD_M = 100  # Starlink maneuvers can be smaller
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def mean_motion_to_sma(n_rev_per_day: float) -> float:
    """Convert mean motion (rev/day) to semi-major axis (km) via Kepler's third law.

    Returns 0.0 for non-positive mean motion (invalid/missing TLE field).
    """
    if n_rev_per_day <= 0:
        return 0.0
    omega = n_rev_per_day * 2 * math.pi / 86400.0  # angular rate, rad/s
    return (MU_EARTH / (omega ** 2)) ** (1.0 / 3.0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def sma_to_altitude(sma_km: float) -> float:
    """Approximate altitude (km) above the equatorial radius for a given SMA."""
    altitude_km = sma_km - EARTH_RADIUS_KM
    return altitude_km
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def parse_tle_epoch(epoch_str: str) -> datetime:
    """Parse a CelesTrak JSON epoch string (ISO 8601 format).

    Tries full-microsecond, whole-second, and date-only forms in that order
    (CelesTrak emits e.g. "2026-02-13T12:00:00.000000").

    Raises:
        ValueError: If none of the known formats match.
    """
    candidate_formats = ("%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d")
    for candidate in candidate_formats:
        try:
            return datetime.strptime(epoch_str, candidate)
        except ValueError:
            pass
    raise ValueError(f"Cannot parse epoch: {epoch_str}")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def extract_orbital_elements(tle_json: dict) -> dict:
    """Extract the orbital elements used for maneuver detection from a CelesTrak JSON entry.

    Missing numeric fields default to 0; an unparseable EPOCH leaves "epoch"
    as None while preserving the raw string in "epoch_str".
    """
    mean_motion = float(tle_json.get("MEAN_MOTION", 0))
    sma = mean_motion_to_sma(mean_motion)

    epoch_str = tle_json.get("EPOCH", "")
    epoch = None
    if epoch_str:
        try:
            epoch = parse_tle_epoch(epoch_str)
        except ValueError:
            epoch = None  # keep the raw string; downstream code tolerates None

    return {
        "norad_id": int(tle_json.get("NORAD_CAT_ID", 0)),
        "name": tle_json.get("OBJECT_NAME", "UNKNOWN"),
        "mean_motion": mean_motion,
        "eccentricity": float(tle_json.get("ECCENTRICITY", 0)),
        "inclination": float(tle_json.get("INCLINATION", 0)),
        "raan": float(tle_json.get("RA_OF_ASC_NODE", 0)),
        "sma_km": sma,
        "altitude_km": sma_to_altitude(sma),
        "epoch": epoch,
        "epoch_str": epoch_str,
    }
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def detect_maneuvers(
    prev_tles: list[dict],
    curr_tles: list[dict],
    threshold_m: float = DEFAULT_DELTA_A_THRESHOLD_M,
) -> list[dict]:
    """Compare two TLE snapshots and detect maneuvers.

    Args:
        prev_tles: Previous TLE snapshot (CelesTrak JSON format)
        curr_tles: Current TLE snapshot (CelesTrak JSON format)
        threshold_m: Semi-major axis change threshold in meters

    Returns:
        Detected maneuvers sorted by |delta-a| descending.
    """
    # Index the previous snapshot by NORAD ID, keeping only valid orbits.
    previous = {}
    for raw in prev_tles:
        parsed = extract_orbital_elements(raw)
        if parsed["norad_id"] > 0 and parsed["sma_km"] > 0:
            previous[parsed["norad_id"]] = parsed

    detected = []
    for raw in curr_tles:
        curr = extract_orbital_elements(raw)
        before = previous.get(curr["norad_id"])
        if before is None or curr["sma_km"] <= 0:
            continue

        delta_a_km = curr["sma_km"] - before["sma_km"]
        delta_a_m = abs(delta_a_km) * 1000
        if delta_a_m <= threshold_m:
            continue

        # Circular-orbit delta-v estimate (Hohmann approximation):
        # dv ≈ da / (2a) * v_circular, reported in m/s.
        v_circular = math.sqrt(MU_EARTH / before["sma_km"])  # km/s
        delta_v = abs(delta_a_km) / (2 * before["sma_km"]) * v_circular * 1000  # m/s

        detected.append({
            "norad_id": curr["norad_id"],
            "name": curr["name"],
            "prev_sma_km": before["sma_km"],
            "curr_sma_km": curr["sma_km"],
            "delta_a_m": delta_a_m,
            "delta_a_km": delta_a_km,
            "delta_v_m_s": round(delta_v, 3),
            "maneuver_type": "orbit_raise" if delta_a_km > 0 else "orbit_lower",
            "altitude_km": curr["altitude_km"],
            "prev_epoch": before["epoch_str"],
            "curr_epoch": curr["epoch_str"],
            "detected_at": datetime.now(timezone.utc).isoformat(),
        })

    # Largest maneuvers first.
    detected.sort(key=lambda m: m["delta_a_m"], reverse=True)
    return detected
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def detect_maneuvers_dual_threshold(
    prev_tles: list[dict],
    curr_tles: list[dict],
) -> list[dict]:
    """Detect maneuvers using constellation-aware thresholds.

    Applies the 100 m threshold to Starlink objects (their maneuvers are
    smaller) and 200 m to everything else, then merges the results sorted by
    |delta-a| descending. The split is a partition, so no NORAD ID can appear
    in both result sets.
    """
    def _partition(tles: list[dict]) -> tuple[list[dict], list[dict]]:
        # Split a snapshot into (starlink, other) by object name.
        starlink, other = [], []
        for tle in tles:
            if "STARLINK" in tle.get("OBJECT_NAME", "").upper():
                starlink.append(tle)
            else:
                other.append(tle)
        return starlink, other

    starlink_prev, other_prev = _partition(prev_tles)
    starlink_curr, other_curr = _partition(curr_tles)

    merged = detect_maneuvers(
        starlink_prev, starlink_curr,
        threshold_m=STARLINK_DELTA_A_THRESHOLD_M,
    )
    merged += detect_maneuvers(
        other_prev, other_curr,
        threshold_m=DEFAULT_DELTA_A_THRESHOLD_M,
    )
    merged.sort(key=lambda m: m["delta_a_m"], reverse=True)
    return merged
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_tle_snapshot(path: Path) -> list[dict]:
    """Read back a TLE snapshot previously saved as JSON."""
    with open(path) as handle:
        snapshot = json.load(handle)
    return snapshot
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def save_tle_snapshot(tles: list[dict], path: Path):
    """Write a TLE snapshot as JSON, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(tles, handle)
|
src/data/merge_sources.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Merge CDM data from multiple sources into unified training format.
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
1. ESA Kelvins dataset (103 features, labeled)
|
| 6 |
+
2. Space-Track cdm_public (16 features, unlabeled — derive risk from PC)
|
| 7 |
+
|
| 8 |
+
Strategy:
|
| 9 |
+
- Space-Track CDMs are grouped into "conjunction events" by (SAT_1_ID, SAT_2_ID, TCA_date)
|
| 10 |
+
- Each event gets a time series of CDMs ordered by CREATED date
|
| 11 |
+
- Risk label derived from final PC: high risk if PC > 1e-5 (same threshold as Kelvins)
|
| 12 |
+
- Features that exist in both sources get unified column names
|
| 13 |
+
- Missing features (e.g., covariance in Space-Track) are filled with 0
|
| 14 |
+
|
| 15 |
+
This gives us far more positive examples for training the risk classifier,
|
| 16 |
+
even though the Space-Track data has fewer features per CDM.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pandas as pd
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from datetime import timedelta
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Mapping from Space-Track CDM_PUBLIC fields → unified column names used by
# the merged training format (t_* = target satellite, c_* = chaser).
SPACETRACK_COLUMN_MAP = {
    "CDM_ID": "cdm_id",
    "CREATED": "created",
    "TCA": "tca",
    "MIN_RNG": "miss_distance",  # km in Space-Track (converted to meters on load)
    "PC": "collision_probability",
    "SAT_1_ID": "sat_1_id",
    "SAT_1_NAME": "sat_1_name",
    "SAT1_OBJECT_TYPE": "t_object_type",
    "SAT1_RCS": "t_rcs",
    "SAT_1_EXCL_VOL": "t_excl_vol",
    "SAT_2_ID": "sat_2_id",
    "SAT_2_NAME": "sat_2_name",
    "SAT2_OBJECT_TYPE": "c_object_type",
    "SAT2_RCS": "c_rcs",
    "SAT_2_EXCL_VOL": "c_excl_vol",
    "EMERGENCY_REPORTABLE": "emergency_reportable",
}

# Risk threshold: PC > 1e-5 = high risk (matches ESA Kelvins: risk > -5,
# since risk there is log10 of collision probability)
RISK_THRESHOLD = 1e-5
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_spacetrack_cdms(csv_path: Path) -> pd.DataFrame:
    """Load a Space-Track CDM CSV and normalize it to the unified format.

    Renames columns per SPACETRACK_COLUMN_MAP, parses timestamps, converts
    miss distance from km to meters, and derives a log10-risk column from PC.
    """
    df = pd.read_csv(csv_path).rename(columns=SPACETRACK_COLUMN_MAP)

    # Timestamps → pandas datetimes; malformed entries become NaT.
    for date_col in ("created", "tca"):
        if date_col in df.columns:
            df[date_col] = pd.to_datetime(df[date_col], errors="coerce")

    # Space-Track MIN_RNG is in km; ESA Kelvins miss_distance is in meters.
    if "miss_distance" in df.columns:
        df["miss_distance"] = pd.to_numeric(df["miss_distance"], errors="coerce") * 1000.0

    if "collision_probability" in df.columns:
        pc = pd.to_numeric(df["collision_probability"], errors="coerce")
        df["collision_probability"] = pc
        # risk = log10(PC), matching the ESA format; non-positive PC → -30 floor.
        df["risk"] = np.where(pc > 0, np.log10(pc.clip(lower=1e-30)), -30.0)

    print(f"Loaded {len(df)} Space-Track CDMs from {csv_path.name}")
    return df
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def group_into_events(df: pd.DataFrame) -> pd.DataFrame:
|
| 85 |
+
"""
|
| 86 |
+
Group Space-Track CDMs into conjunction events.
|
| 87 |
+
|
| 88 |
+
An 'event' is a sequence of CDMs for the same object pair with TCA
|
| 89 |
+
values within 1 day of each other. Each event gets a unique event_id.
|
| 90 |
+
"""
|
| 91 |
+
if df.empty:
|
| 92 |
+
return df
|
| 93 |
+
|
| 94 |
+
# Sort by object pair and TCA
|
| 95 |
+
df = df.sort_values(["sat_1_id", "sat_2_id", "tca", "created"]).reset_index(drop=True)
|
| 96 |
+
|
| 97 |
+
# Assign event IDs: same pair + TCA within 1 day = same event
|
| 98 |
+
event_ids = []
|
| 99 |
+
current_event = 0
|
| 100 |
+
prev_sat1 = None
|
| 101 |
+
prev_sat2 = None
|
| 102 |
+
prev_tca = None
|
| 103 |
+
|
| 104 |
+
for _, row in df.iterrows():
|
| 105 |
+
sat1 = row.get("sat_1_id")
|
| 106 |
+
sat2 = row.get("sat_2_id")
|
| 107 |
+
tca = row.get("tca")
|
| 108 |
+
|
| 109 |
+
same_pair = (sat1 == prev_sat1 and sat2 == prev_sat2)
|
| 110 |
+
close_tca = False
|
| 111 |
+
if same_pair and prev_tca is not None and pd.notna(tca) and pd.notna(prev_tca):
|
| 112 |
+
close_tca = abs((tca - prev_tca).total_seconds()) < 86400 # 1 day
|
| 113 |
+
|
| 114 |
+
if not (same_pair and close_tca):
|
| 115 |
+
current_event += 1
|
| 116 |
+
|
| 117 |
+
event_ids.append(current_event)
|
| 118 |
+
prev_sat1 = sat1
|
| 119 |
+
prev_sat2 = sat2
|
| 120 |
+
prev_tca = tca
|
| 121 |
+
|
| 122 |
+
df["event_id"] = event_ids
|
| 123 |
+
|
| 124 |
+
# Compute time_to_tca: days from CDM creation to TCA (for each CDM in event)
|
| 125 |
+
if "created" in df.columns and "tca" in df.columns:
|
| 126 |
+
df["time_to_tca"] = (df["tca"] - df["created"]).dt.total_seconds() / 86400.0
|
| 127 |
+
df["time_to_tca"] = df["time_to_tca"].clip(lower=0.0)
|
| 128 |
+
|
| 129 |
+
n_events = df["event_id"].nunique()
|
| 130 |
+
n_high_risk = 0
|
| 131 |
+
if "risk" in df.columns:
|
| 132 |
+
event_risks = df.groupby("event_id")["risk"].last()
|
| 133 |
+
n_high_risk = (event_risks > -5).sum()
|
| 134 |
+
|
| 135 |
+
print(f"Grouped into {n_events} events ({n_high_risk} high-risk)")
|
| 136 |
+
return df
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def compute_relative_speed_from_excl_vol(df: pd.DataFrame) -> pd.DataFrame:
    """Guarantee a ``relative_speed`` column exists for schema compatibility.

    The exclusion volume (km) alone cannot yield a relative speed, so when
    the column is absent it is simply filled with 0.0.
    """
    if "relative_speed" not in df:
        df["relative_speed"] = 0.0
    return df
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def align_with_kelvins_schema(
    spacetrack_df: pd.DataFrame,
    kelvins_df: pd.DataFrame,
) -> pd.DataFrame:
    """Align Space-Track data columns with the Kelvins schema.

    Columns present in ``kelvins_df`` but missing from ``spacetrack_df``
    are added and filled with 0.0. The result keeps only Kelvins columns
    plus a fixed set of Space-Track metadata columns.

    Args:
        spacetrack_df: Space-Track CDM DataFrame. Unlike the previous
            implementation, the input is NOT modified in place.
        kelvins_df: reference DataFrame defining the target schema.

    Returns:
        A new DataFrame restricted to Kelvins columns + extra metadata.
    """
    # Work on a copy so the caller's DataFrame is not mutated as a side
    # effect (the original version added columns directly to the argument).
    spacetrack_df = spacetrack_df.copy()

    kelvins_cols = set(kelvins_df.columns)

    # Add all missing numeric columns in a single assign() call to avoid
    # DataFrame fragmentation from repeated single-column inserts.
    missing = kelvins_cols - set(spacetrack_df.columns)
    if missing:
        spacetrack_df = spacetrack_df.assign(**{col: 0.0 for col in missing})

    # Keep only columns that exist in Kelvins + our extra metadata
    extra_cols = {"sat_1_id", "sat_2_id", "sat_1_name", "sat_2_name",
                  "t_object_type", "collision_probability", "created", "tca",
                  "cdm_id", "emergency_reportable", "t_rcs", "c_rcs",
                  "t_excl_vol", "c_excl_vol", "source"}
    keep_cols = list(kelvins_cols | extra_cols)
    available = [c for c in keep_cols if c in spacetrack_df.columns]
    return spacetrack_df[available]
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def merge_datasets(
    kelvins_train_df: pd.DataFrame,
    spacetrack_df: pd.DataFrame,
    offset_event_ids: bool = True,
) -> pd.DataFrame:
    """
    Merge Kelvins training data with Space-Track CDMs.

    Args:
        kelvins_train_df: ESA Kelvins training DataFrame
        spacetrack_df: Space-Track CDMs (already grouped into events)
        offset_event_ids: shift Space-Track event_ids to avoid collisions

    Returns:
        Combined DataFrame ready for model training
    """
    # Copy both inputs and tag each row with its origin so the sources
    # remain distinguishable after concatenation.
    kelvins_train_df = kelvins_train_df.copy()
    kelvins_train_df["source"] = "kelvins"
    spacetrack_df = spacetrack_df.copy()
    spacetrack_df["source"] = "spacetrack"

    # Shift Space-Track event IDs past the largest Kelvins ID so the two
    # ID spaces never collide.
    if offset_event_ids and "event_id" in kelvins_train_df.columns:
        spacetrack_df["event_id"] += kelvins_train_df["event_id"].max() + 1

    # Bring Space-Track columns onto the Kelvins schema, then stack.
    spacetrack_df = align_with_kelvins_schema(spacetrack_df, kelvins_train_df)
    combined = pd.concat([kelvins_train_df, spacetrack_df], ignore_index=True)

    # Zero-fill any numeric gaps left over from the schema alignment.
    numeric_cols = combined.select_dtypes(include=[np.number]).columns
    combined[numeric_cols] = combined[numeric_cols].fillna(0)

    n_kelvins = kelvins_train_df["event_id"].nunique()
    n_st = spacetrack_df["event_id"].nunique()
    n_total = combined["event_id"].nunique()

    # High-risk event counts per source (risk > -5 <=> Pc > 1e-5); the
    # event-level risk is the last CDM's risk.
    event_risk = combined.groupby(["event_id", "source"])["risk"].last().reset_index()
    n_hr_kelvins = ((event_risk["source"] == "kelvins") & (event_risk["risk"] > -5)).sum()
    n_hr_st = ((event_risk["source"] == "spacetrack") & (event_risk["risk"] > -5)).sum()

    print(f"\nMerged dataset:")
    print(f" Kelvins: {n_kelvins} events ({n_hr_kelvins} high-risk)")
    print(f" Space-Track: {n_st} events ({n_hr_st} high-risk)")
    print(f" Total: {n_total} events ({n_hr_kelvins + n_hr_st} high-risk)")
    print(f" Columns: {len(combined.columns)}")

    return combined
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def load_and_merge_all(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load all available data sources and merge into train/test DataFrames.

    Returns (train_df, test_df) — test is Kelvins-only (for fair comparison).
    """
    # Function-scope import; NOTE(review): presumably deferred to avoid a
    # circular import with src.data.cdm_loader — confirm.
    from src.data.cdm_loader import load_dataset

    # Load ESA Kelvins
    kelvins_dir = data_dir / "cdm"
    kelvins_train, kelvins_test = load_dataset(kelvins_dir)

    # Load Space-Track data if available
    spacetrack_dir = data_dir / "cdm_spacetrack"
    spacetrack_files = list(spacetrack_dir.glob("cdm_*.csv")) if spacetrack_dir.exists() else []

    if not spacetrack_files:
        print("\nNo Space-Track data found. Using Kelvins only.")
        return kelvins_train, kelvins_test

    # Load and merge all Space-Track CSVs
    st_dfs = []
    for f in spacetrack_files:
        # Skip partial/in-progress download files.
        if f.name.startswith("checkpoint"):
            continue
        df = load_spacetrack_cdms(f)
        df = group_into_events(df)
        df = compute_relative_speed_from_excl_vol(df)
        st_dfs.append(df)

    if st_dfs:
        all_st = pd.concat(st_dfs, ignore_index=True)
        # Re-assign event IDs after concatenation: per-file grouping restarts
        # the event counter, so IDs must be rebuilt to be globally unique.
        all_st = group_into_events(all_st)
        merged_train = merge_datasets(kelvins_train, all_st)
    else:
        merged_train = kelvins_train

    # Test set stays Kelvins-only for fair benchmarking
    return merged_train, kelvins_test
|
src/data/sequence_builder.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Build padded CDM sequences for the Temporal Fusion Transformer.
|
| 3 |
+
|
| 4 |
+
Each conjunction event is a variable-length time series of CDM snapshots.
|
| 5 |
+
This module handles:
|
| 6 |
+
- Selecting temporal vs static features
|
| 7 |
+
- Padding/truncating to fixed length
|
| 8 |
+
- Creating attention masks for padded positions
|
| 9 |
+
- Train/val/test splitting with stratification
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import torch
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
from sklearn.model_selection import train_test_split
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
# Maximum CDM sequence length (95th percentile of real data is ~25)
MAX_SEQ_LEN = 30

# Features that change with each CDM update (time-varying).
# Names follow the ESA Kelvins column convention: t_* = target object,
# c_* = chaser object; r/t/n = radial/transverse/normal frame components.
TEMPORAL_FEATURES = [
    "miss_distance",
    "relative_speed",
    "relative_position_r", "relative_position_t", "relative_position_n",
    "relative_velocity_r", "relative_velocity_t", "relative_velocity_n",
    "max_risk_estimate", "max_risk_scaling",
    # Target object covariance
    "t_sigma_r", "t_sigma_t", "t_sigma_n",
    "t_sigma_rdot", "t_sigma_tdot", "t_sigma_ndot",
    # Chaser object covariance
    "c_sigma_r", "c_sigma_t", "c_sigma_n",
    "c_sigma_rdot", "c_sigma_tdot", "c_sigma_ndot",
]

# Features that are constant per event (object properties):
# apogee/perigee height, semi-major axis, inclination, eccentricity, span.
STATIC_FEATURES = [
    "t_h_apo", "t_h_per", "t_j2k_sma", "t_j2k_inc", "t_ecc",
    "c_h_apo", "c_h_per", "c_j2k_sma", "c_j2k_inc", "c_ecc",
    "t_span", "c_span",
]

# Orbital density features from CRASH Clock analysis (added by OrbitalDensityComputer)
DENSITY_FEATURES = [
    "shell_density",
    "shell_collision_rate",
    "local_crash_clock_log",
    "altitude_percentile",
    "n_events_in_shell",
    "shell_risk_rate",
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def find_available_features(df: pd.DataFrame, candidates: list[str]) -> list[str]:
    """Return the subset of *candidates* that exist as columns in *df*.

    Original candidate order is preserved; a short note is printed when
    some candidates are absent.
    """
    present = []
    absent = []
    for name in candidates:
        (present if name in df.columns else absent).append(name)
    if absent:
        print(f" Note: {len(absent)} features not in dataset, using {len(present)}")
    return present
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CDMSequenceDataset(Dataset):
    """
    PyTorch Dataset that serves padded CDM sequences for the Transformer.

    Each item contains:
    - temporal_features: (S, F_t) tensor of time-varying CDM features
    - static_features: (F_s,) tensor of object properties
    - time_to_tca: (S, 1) tensor of time-to-closest-approach values
    - mask: (S,) boolean mask (True = real data, False = padding)
    - risk_label: scalar binary target
    - miss_distance_log: scalar log1p(final_miss_distance) target
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] = None,
        static_cols: list[str] = None,
    ):
        """Build per-event CDM sequences and fit normalization stats on *df*.

        Args:
            df: CDM rows; must contain `event_id`, `time_to_tca`, `risk`,
                plus the temporal/static feature columns.
            max_seq_len: fixed sequence length after padding/truncation.
            temporal_cols: override for the time-varying feature columns.
            static_cols: override for the per-event static feature columns.
        """
        self.max_seq_len = max_seq_len

        # Find available features (module-level lists filtered to columns
        # actually present in df, unless explicit overrides are given).
        self.temporal_cols = temporal_cols or find_available_features(df, TEMPORAL_FEATURES)
        self.static_cols = static_cols or find_available_features(df, STATIC_FEATURES)

        print(f" Temporal features: {len(self.temporal_cols)}")
        print(f" Static features: {len(self.static_cols)}")

        # Group by event_id
        self.events = []
        for event_id, group in df.groupby("event_id"):
            # Sort by time_to_tca descending (first CDM = furthest from TCA)
            group = group.sort_values("time_to_tca", ascending=False)
            # Track data source for domain weighting
            source = "kelvins"
            if "source" in group.columns:
                source = group["source"].iloc[0]
            self.events.append({
                "event_id": event_id,
                "group": group,
                "source": source,
            })

        # Compute global normalization stats from training data.
        # NOTE(review): an all-NaN feature column yields NaN mean/std; the
        # `< 1e-8` floor below does not catch NaN — confirm upstream fill.
        self.temporal_mean = df[self.temporal_cols].mean().values.astype(np.float32)
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[self.temporal_std < 1e-8] = 1.0  # avoid div by zero

        self.static_mean = df[self.static_cols].mean().values.astype(np.float32)
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[self.static_std < 1e-8] = 1.0

        # Normalize time_to_tca
        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if self.tca_std < 1e-8:
            self.tca_std = 1.0

        # Compute delta normalization stats (approx from per-step differences)
        # Deltas have different magnitude than raw features, need separate stats
        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame) -> None:
        """Estimate normalization stats for temporal first-order differences."""
        # Sample a subset of events to estimate delta distributions
        delta_samples = []
        for _, group in df.groupby("event_id"):
            # Single-CDM events have no differences to sample.
            if len(group) < 2:
                continue
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[self.delta_std < 1e-8] = 1.0
        else:
            # Fallback: identity normalization when no deltas are available.
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other: "CDMSequenceDataset") -> None:
        """Copy normalization stats from another dataset (e.g., training set)."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self) -> int:
        """Number of conjunction events (one item per event)."""
        return len(self.events)

    def __getitem__(self, idx: int) -> dict:
        """Return one padded event as a dict of tensors (see class docstring)."""
        event = self.events[idx]
        group = event["group"]

        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute first-order differences (deltas) for temporal features
        # This captures trends: is miss_distance shrinking? Is covariance tightening?
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)  # (seq_len-1, n_temporal)
            # Prepend zeros for the first timestep (no prior to diff against)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)

        # Normalize raw features and deltas separately
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std

        # Concatenate: (seq_len, n_temporal * 2)
        temporal = np.concatenate([temporal, deltas], axis=1)

        # Extract static features from last row (they're constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)

        # Time-to-TCA values: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)

        # Normalize
        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std

        # Truncate or pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            # Keep the most recent CDMs (closest to TCA = most informative)
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len

        # Pad (left-pad so the most recent CDM is always at position -1)
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True

        # Target: risk label from final CDM's risk column (rows are sorted by
        # time_to_tca descending, so iloc[-1] is the CDM closest to TCA).
        # risk > -5 means collision probability > 1e-5 (high risk)
        final_risk = group["risk"].iloc[-1]
        risk_label = 1.0 if final_risk > -5 else 0.0

        # Target: log1p of final miss distance
        final_miss = group["miss_distance"].iloc[-1] if "miss_distance" in group.columns else 0.0
        miss_log = np.log1p(max(final_miss, 0.0))

        # Target: log10(Pc) — the Kelvins `risk` column is already log10(Pc).
        # Clamp to [-20, 0] (Pc ranges from ~1e-20 to ~1)
        pc_log10 = float(max(min(final_risk, 0.0), -20.0))

        # Domain weight: Kelvins events get full weight, Space-Track events
        # get reduced weight since they have sparse features (16 vs 103 columns).
        # This prevents the model from learning shortcuts on zero-padded features.
        source = event.get("source", "kelvins")
        domain_weight = 1.0 if source == "kelvins" else 0.3

        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
            "risk_label": torch.tensor(risk_label, dtype=torch.float32),
            "miss_log": torch.tensor(miss_log, dtype=torch.float32),
            "pc_log10": torch.tensor(pc_log10, dtype=torch.float32),
            "domain_weight": torch.tensor(domain_weight, dtype=torch.float32),
        }
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class PretrainDataset(Dataset):
    """Simplified CDM dataset for self-supervised pre-training (no labels needed).

    Returns only temporal features, static features, time_to_tca, and mask.
    Can process combined train+test data since labels aren't used.

    NOTE(review): feature extraction, normalization, and padding mirror
    CDMSequenceDataset; keep the two implementations in sync.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] = None,
        static_cols: list[str] = None,
    ):
        """Build unlabeled per-event sequences and fit normalization on *df*."""
        self.max_seq_len = max_seq_len

        # Filter feature lists to columns present in df (unless overridden).
        self.temporal_cols = temporal_cols or find_available_features(df, TEMPORAL_FEATURES)
        self.static_cols = static_cols or find_available_features(df, STATIC_FEATURES)

        print(f" PretrainDataset — Temporal: {len(self.temporal_cols)}, Static: {len(self.static_cols)}")

        # Group by event_id; rows sorted so the CDM closest to TCA is last.
        self.events = []
        for event_id, group in df.groupby("event_id"):
            group = group.sort_values("time_to_tca", ascending=False)
            self.events.append({"event_id": event_id, "group": group})

        # Compute global normalization stats
        self.temporal_mean = df[self.temporal_cols].mean().values.astype(np.float32)
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[self.temporal_std < 1e-8] = 1.0  # avoid div by zero

        self.static_mean = df[self.static_cols].mean().values.astype(np.float32)
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[self.static_std < 1e-8] = 1.0

        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if self.tca_std < 1e-8:
            self.tca_std = 1.0

        # Separate stats for first-order differences (different magnitudes).
        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame) -> None:
        """Estimate normalization stats for temporal first-order differences."""
        delta_samples = []
        for _, group in df.groupby("event_id"):
            # Single-CDM events have no differences to sample.
            if len(group) < 2:
                continue
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[self.delta_std < 1e-8] = 1.0
        else:
            # Fallback: identity normalization when no deltas are available.
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other) -> None:
        """Copy normalization stats from another dataset."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self) -> int:
        """Number of conjunction events (one item per event)."""
        return len(self.events)

    def __getitem__(self, idx: int) -> dict:
        """Return one padded, unlabeled event as a dict of tensors."""
        event = self.events[idx]
        group = event["group"]

        # Extract temporal features
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute first-order differences (zeros prepended for the first step)
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)

        # Normalize raw features and deltas separately, then concatenate
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std
        temporal = np.concatenate([temporal, deltas], axis=1)

        # Static features (constant per event; take last row)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)

        # Time-to-TCA
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)

        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std

        # Truncate (keep most recent CDMs) or pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len

        # Left-pad so the most recent CDM is always at position -1
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True

        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
        }
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def build_datasets(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    val_fraction: float = 0.1,
    use_density: bool = False,
    cal_fraction: float = 0.0,
) -> tuple:
    """
    Build train, validation, and test datasets with shared normalization.

    Splits training data into train + val by event_id (stratified by risk).

    Args:
        train_df: Training CDM DataFrame
        test_df: Test CDM DataFrame
        val_fraction: Fraction of Kelvins training events for validation
        use_density: If True, include DENSITY_FEATURES in static features
        cal_fraction: If > 0, further split validation into val + calibration
            for conformal prediction. Returns 4-tuple instead of 3.

    Returns:
        If cal_fraction == 0: (train_ds, val_ds, test_ds)
        If cal_fraction > 0: (train_ds, val_ds, cal_ds, test_ds)
    """
    # Compute density features if requested
    if use_density:
        from src.data.density_features import OrbitalDensityComputer
        density_computer = OrbitalDensityComputer()
        density_computer.fit(train_df)
        train_df = density_computer.transform(train_df)
        test_df = density_computer.transform(test_df)
    else:
        density_computer = None

    # Static columns: base (filtered to available) + optional density
    static_cols = [c for c in STATIC_FEATURES if c in train_df.columns]
    if use_density:
        static_cols = static_cols + [
            f for f in DENSITY_FEATURES if f in train_df.columns
        ]

    # Determine risk label per event for stratification (event risk = last
    # CDM's risk; label 1 <=> risk > -5 <=> Pc > 1e-5).
    has_source = "source" in train_df.columns
    agg_dict = {"risk": ("risk", "last")}
    if has_source:
        agg_dict["source"] = ("source", "first")
    event_meta = train_df.groupby("event_id").agg(**agg_dict).reset_index()
    event_meta["label"] = (event_meta["risk"] > -5).astype(int)

    # Label lookup keyed by event_id, so label arrays can be aligned to any
    # ordering of event ids (see calibration split below).
    label_by_id = event_meta.set_index("event_id")["label"]

    # Split validation from KELVINS-ONLY events for fair model selection.
    # Space-Track events (sparse features, all high-risk) inflate val metrics.
    if has_source:
        kelvins_events = event_meta[event_meta["source"] == "kelvins"]
        other_events = event_meta[event_meta["source"] != "kelvins"]

        kelvins_ids = kelvins_events["event_id"].values
        kelvins_labels = kelvins_events["label"].values

        # Stratified split on Kelvins events only
        k_train_ids, val_ids = train_test_split(
            kelvins_ids, test_size=val_fraction, stratify=kelvins_labels, random_state=42
        )
        # Training = Kelvins train split + all Space-Track events
        train_ids = np.concatenate([k_train_ids, other_events["event_id"].values])
    else:
        event_ids = event_meta["event_id"].values
        labels = event_meta["label"].values
        train_ids, val_ids = train_test_split(
            event_ids, test_size=val_fraction, stratify=labels, random_state=42
        )

    # Further split validation into val + calibration for conformal prediction
    cal_ids = np.array([])
    if cal_fraction > 0 and len(val_ids) > 20:
        # BUGFIX: labels must be aligned elementwise with val_ids. The
        # previous boolean-mask lookup (event_meta[...isin(val_ids)]) follows
        # event_meta row order, not val_ids order, so stratify= could pair
        # labels with the wrong event ids.
        val_labels = label_by_id.loc[val_ids].values
        val_ids, cal_ids = train_test_split(
            val_ids,
            test_size=cal_fraction,
            stratify=val_labels,
            random_state=123,  # different seed from train/val split
        )

    train_sub = train_df[train_df["event_id"].isin(train_ids)]
    val_sub = train_df[train_df["event_id"].isin(val_ids)]

    print(f"Building datasets:")
    print(f" Train events: {len(train_ids)}")
    if has_source:
        n_k = train_sub[train_sub["source"] == "kelvins"]["event_id"].nunique()
        n_s = train_sub[train_sub["source"] != "kelvins"]["event_id"].nunique()
        print(f" (Kelvins: {n_k}, Space-Track: {n_s})")
    if use_density:
        print(f" Static features: {len(static_cols)} (base: {len(STATIC_FEATURES)}, "
              f"density: {len(static_cols) - len(STATIC_FEATURES)})")

    train_ds = CDMSequenceDataset(train_sub, static_cols=static_cols)

    print(f" Val events: {len(val_ids)} (Kelvins-only)")
    val_ds = CDMSequenceDataset(val_sub, static_cols=static_cols)
    val_ds.set_normalization(train_ds)  # use training stats

    print(f" Test events: {test_df['event_id'].nunique()}")
    test_ds = CDMSequenceDataset(test_df, temporal_cols=train_ds.temporal_cols, static_cols=static_cols)
    test_ds.set_normalization(train_ds)

    # Store density computer on train_ds for checkpoint saving
    if density_computer is not None:
        train_ds._density_computer = density_computer

    if cal_fraction > 0 and len(cal_ids) > 0:
        cal_sub = train_df[train_df["event_id"].isin(cal_ids)]
        print(f" Cal events: {len(cal_ids)} (for conformal prediction)")
        cal_ds = CDMSequenceDataset(cal_sub, static_cols=static_cols)
        cal_ds.set_normalization(train_ds)
        return train_ds, val_ds, cal_ds, test_ds

    return train_ds, val_ds, test_ds
|
src/data/spacetrack_crossref.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cross-reference detected maneuvers with Space-Track.org CDM data.
|
| 2 |
+
|
| 3 |
+
Queries the CDM_PUBLIC class for recent conjunction data messages
|
| 4 |
+
involving maneuvered satellites. CDM confirmation is the strongest
|
| 5 |
+
signal that a maneuver was collision-avoidance.
|
| 6 |
+
|
| 7 |
+
Requires SPACETRACK_USER and SPACETRACK_PASS environment variables.
|
| 8 |
+
Fails silently if credentials are not set (purely enrichment).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import time
|
| 14 |
+
import requests
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from datetime import datetime, timedelta, timezone
|
| 17 |
+
|
| 18 |
+
# Rate limiting: max 30 requests/min to Space-Track
|
| 19 |
+
MAX_REQUESTS_PER_MIN = 30
|
| 20 |
+
BATCH_SIZE = 100 # Max NORAD IDs per query
|
| 21 |
+
CACHE_EXPIRY_DAYS = 7
|
| 22 |
+
|
| 23 |
+
SPACETRACK_BASE = "https://www.space-track.org"
|
| 24 |
+
LOGIN_URL = f"{SPACETRACK_BASE}/ajaxauth/login"
|
| 25 |
+
CDM_QUERY_URL = f"{SPACETRACK_BASE}/basicspacedata/query/class/cdm_public"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _get_credentials() -> tuple[str, str]:
|
| 29 |
+
"""Get Space-Track credentials from environment."""
|
| 30 |
+
user = os.environ.get("SPACETRACK_USER", "")
|
| 31 |
+
passwd = os.environ.get("SPACETRACK_PASS", "")
|
| 32 |
+
return user, passwd
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _load_cache(cache_path: Path) -> dict:
    """Read the on-disk CDM cache, dropping entries older than CACHE_EXPIRY_DAYS.

    Returns:
        The cache dict with only fresh entries; an empty dict when the
        file is missing, unreadable, or not valid JSON.
    """
    if not cache_path.exists():
        return {}

    try:
        with open(cache_path) as fh:
            raw_cache = json.load(fh)
    except (json.JSONDecodeError, IOError):
        return {}

    # An entry survives only if its "cached_at" ISO-8601 timestamp is
    # newer than the expiry cutoff (string comparison works for ISO dates).
    expiry_point = datetime.now(timezone.utc) - timedelta(days=CACHE_EXPIRY_DAYS)
    cutoff = expiry_point.isoformat()
    fresh = {}
    for key, entry in raw_cache.items():
        if entry.get("cached_at", "") > cutoff:
            fresh[key] = entry
    return fresh
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _save_cache(cache: dict, cache_path: Path):
|
| 55 |
+
"""Save CDM cache to disk."""
|
| 56 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
| 57 |
+
with open(cache_path, "w") as f:
|
| 58 |
+
json.dump(cache, f, indent=2)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def check_cdm_for_norad_ids(
    norad_ids: list[int],
    lookback_days: int = 7,
    min_pc: float = 1e-7,
    cache_dir: Path | None = None,
) -> dict[int, list[dict]]:
    """Query Space-Track CDM_PUBLIC for recent CDMs involving given satellites.

    Results are cached on disk (see ``_load_cache`` / ``_save_cache``) so
    repeated calls within CACHE_EXPIRY_DAYS do not re-hit the API. Failed
    queries are cached as empty lists to avoid hammering Space-Track.

    Args:
        norad_ids: NORAD catalog IDs to check.
        lookback_days: How far back to search for CDMs.
        min_pc: Minimum probability of collision to include. NOTE(review):
            this filter is applied client-side and the *filtered* result is
            cached, so a later call with a smaller min_pc can return cached,
            over-filtered data until the cache entry expires.
        cache_dir: Directory for CDM cache file. Defaults to data/prediction_logs/.

    Returns:
        Map of norad_id -> list of CDM records with PC, TCA, MISS_DISTANCE.
        Empty dict if credentials not set; partial results if queries fail.
    """
    user, passwd = _get_credentials()
    if not user or not passwd:
        # Purely enrichment: without credentials we silently do nothing.
        return {}

    if cache_dir is None:
        cache_dir = Path(__file__).parent.parent.parent / "data" / "prediction_logs"

    cache_path = cache_dir / "cdm_cache.json"
    cache = _load_cache(cache_path)

    # Serve cached IDs immediately; only query Space-Track for the rest.
    results = {}
    uncached_ids = []

    for nid in norad_ids:
        key = str(nid)
        if key in cache:
            results[nid] = cache[key].get("cdms", [])
        else:
            uncached_ids.append(nid)

    if not uncached_ids:
        return results

    # Authenticate with Space-Track; the session cookie is reused for queries.
    try:
        session = requests.Session()
        resp = session.post(LOGIN_URL, data={
            "identity": user,
            "password": passwd,
        }, timeout=30)
        resp.raise_for_status()
    except Exception as e:
        print(f" Space-Track login failed: {e}")
        return results

    # Only the lookback bound is needed for the TCA filter.
    lookback_str = (datetime.now(timezone.utc) - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    for batch_start in range(0, len(uncached_ids), BATCH_SIZE):
        batch = uncached_ids[batch_start:batch_start + BATCH_SIZE]
        ids_str = ",".join(str(nid) for nid in batch)

        # REST-style query: CDMs where SAT1 is any of the batch IDs and TCA
        # falls within the lookback window, newest first.
        query_url = (
            f"{CDM_QUERY_URL}"
            f"/SAT1_NORAD_CAT_ID/{ids_str}"
            f"/TCA/>{lookback_str}"
            f"/orderby/TCA desc"
            f"/format/json"
        )

        try:
            resp = session.get(query_url, timeout=60)
            resp.raise_for_status()
            cdm_records = resp.json()
        except Exception as e:
            print(f" Space-Track CDM query failed: {e}")
            # Cache empty results for failed IDs to avoid re-querying
            for nid in batch:
                cache[str(nid)] = {
                    "cdms": [],
                    "cached_at": datetime.now(timezone.utc).isoformat(),
                }
            continue

        # Keep only records above the PC threshold, grouped by SAT1 ID.
        batch_results: dict[int, list[dict]] = {nid: [] for nid in batch}

        for cdm in cdm_records:
            try:
                pc = float(cdm.get("PC", 0) or 0)
                if pc < min_pc:
                    continue

                sat1_id = int(cdm.get("SAT1_NORAD_CAT_ID", 0))
                record = {
                    "tca": cdm.get("TCA", ""),
                    "pc": pc,
                    # Divide by 1000 to convert MISS_DISTANCE (assumed meters
                    # per the CDM standard -- confirm against API docs) to km.
                    "miss_distance_km": float(cdm.get("MISS_DISTANCE", 0) or 0) / 1000.0,
                    "sat1_name": cdm.get("SAT1_NAME", ""),
                    "sat2_name": cdm.get("SAT2_NAME", ""),
                    "sat2_norad": int(cdm.get("SAT2_NORAD_CAT_ID", 0) or 0),
                }

                if sat1_id in batch_results:
                    batch_results[sat1_id].append(record)
            except (ValueError, TypeError):
                # Malformed record: skip it rather than abort the batch.
                continue

        # Update cache and results
        for nid in batch:
            cdms = batch_results.get(nid, [])
            results[nid] = cdms
            cache[str(nid)] = {
                "cdms": cdms,
                "cached_at": datetime.now(timezone.utc).isoformat(),
            }

        # Rate limiting between batches
        if batch_start + BATCH_SIZE < len(uncached_ids):
            time.sleep(60.0 / MAX_REQUESTS_PER_MIN)

    # Save updated cache
    _save_cache(cache, cache_path)

    return results
|
src/evaluation/__init__.py
ADDED
|
File without changes
|
src/evaluation/conformal.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code — 2026-02-13
|
| 2 |
+
"""Conformal prediction for calibrated risk bounds.
|
| 3 |
+
|
| 4 |
+
Provides distribution-free prediction sets with guaranteed marginal coverage:
|
| 5 |
+
P(true_label ∈ prediction_set) ≥ 1 - alpha
|
| 6 |
+
|
| 7 |
+
This directly addresses NASA CARA's criticism about uncertainty quantification
|
| 8 |
+
in ML-based collision risk assessment. Instead of a single probability, we
|
| 9 |
+
output a prediction set (e.g., {LOW, MODERATE}) that provably covers the
|
| 10 |
+
true risk tier at the specified confidence level.
|
| 11 |
+
|
| 12 |
+
Method: Split conformal prediction (Vovk et al. 2005, Lei et al. 2018)
|
| 13 |
+
- Calibrate on a held-out set separate from training AND model selection
|
| 14 |
+
- Compute nonconformity scores
|
| 15 |
+
- Use quantile of calibration scores to construct prediction sets at test time
|
| 16 |
+
|
| 17 |
+
References:
|
| 18 |
+
- Vovk, Gammerman, Shafer (2005) "Algorithmic Learning in a Random World"
|
| 19 |
+
- Lei et al. (2018) "Distribution-Free Predictive Inference for Regression"
|
| 20 |
+
- Angelopoulos & Bates (2021) "A Gentle Introduction to Conformal Prediction"
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ConformalResult:
    """Conformal-prediction output for a single test example.

    Bundles the tier prediction set together with the probability
    interval it was derived from.
    """
    prediction_set: list[str]  # risk tiers consistent with the interval, e.g., ["LOW", "MODERATE"]
    set_size: int  # |prediction_set| (number of tiers in the set)
    risk_prob: float  # raw model probability (point estimate)
    lower_bound: float  # lower probability bound (clipped at 0.0)
    upper_bound: float  # upper probability bound (clipped at 1.0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ConformalPredictor:
|
| 38 |
+
"""Split conformal prediction for binary risk classification.
|
| 39 |
+
|
| 40 |
+
Workflow:
|
| 41 |
+
1. Train model on training set
|
| 42 |
+
2. Select model (early stopping) on validation set
|
| 43 |
+
3. calibrate() on a SEPARATE calibration set (held out from validation)
|
| 44 |
+
4. predict() on test data with coverage guarantee
|
| 45 |
+
|
| 46 |
+
The calibration set must NOT be used for training or model selection,
|
| 47 |
+
otherwise the coverage guarantee is invalidated.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
# Risk tiers with thresholds
|
| 51 |
+
TIERS = {
|
| 52 |
+
"LOW": (0.0, 0.10),
|
| 53 |
+
"MODERATE": (0.10, 0.40),
|
| 54 |
+
"HIGH": (0.40, 0.70),
|
| 55 |
+
"CRITICAL": (0.70, 1.0),
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
def __init__(self):
|
| 59 |
+
self.quantile_lower = None # q_hat for lower bound
|
| 60 |
+
self.quantile_upper = None # q_hat for upper bound
|
| 61 |
+
self.alpha = None
|
| 62 |
+
self.n_cal = 0
|
| 63 |
+
self.is_calibrated = False
|
| 64 |
+
|
| 65 |
+
def calibrate(
|
| 66 |
+
self,
|
| 67 |
+
cal_probs: np.ndarray,
|
| 68 |
+
cal_labels: np.ndarray,
|
| 69 |
+
alpha: float = 0.10,
|
| 70 |
+
) -> dict:
|
| 71 |
+
"""Calibrate conformal predictor on held-out calibration set.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
cal_probs: Model predicted probabilities on calibration set, shape (n,)
|
| 75 |
+
cal_labels: True binary labels on calibration set, shape (n,)
|
| 76 |
+
alpha: Desired miscoverage rate. 1-alpha = coverage level.
|
| 77 |
+
alpha=0.10 → 90% coverage guarantee.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Calibration summary dict with quantiles and statistics
|
| 81 |
+
"""
|
| 82 |
+
n = len(cal_probs)
|
| 83 |
+
if n < 10:
|
| 84 |
+
raise ValueError(f"Calibration set too small: {n} examples (need >= 10)")
|
| 85 |
+
|
| 86 |
+
self.alpha = alpha
|
| 87 |
+
self.n_cal = n
|
| 88 |
+
|
| 89 |
+
# Nonconformity score: how "wrong" is the model on each calibration example?
|
| 90 |
+
# For binary classification with probabilities:
|
| 91 |
+
# score = 1 - P(true class)
|
| 92 |
+
# High score = model is wrong/uncertain
|
| 93 |
+
scores = np.where(
|
| 94 |
+
cal_labels == 1,
|
| 95 |
+
1.0 - cal_probs, # positive: score = 1 - P(positive)
|
| 96 |
+
cal_probs, # negative: score = P(positive) = 1 - P(negative)
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Conformal quantile: includes finite-sample correction
|
| 100 |
+
# q_hat = ceil((n+1)(1-alpha))/n -th quantile of scores
|
| 101 |
+
adjusted_level = np.ceil((n + 1) * (1 - alpha)) / n
|
| 102 |
+
adjusted_level = min(adjusted_level, 1.0)
|
| 103 |
+
self.q_hat = float(np.quantile(scores, adjusted_level))
|
| 104 |
+
|
| 105 |
+
# For prediction intervals on the probability itself:
|
| 106 |
+
# We also compute quantiles for constructing upper/lower prob bounds
|
| 107 |
+
# Using calibration residuals: |P(positive) - is_positive|
|
| 108 |
+
residuals = np.abs(cal_probs - cal_labels.astype(float))
|
| 109 |
+
self.q_residual = float(np.quantile(residuals, adjusted_level))
|
| 110 |
+
|
| 111 |
+
self.is_calibrated = True
|
| 112 |
+
|
| 113 |
+
# Report calibration statistics
|
| 114 |
+
empirical_coverage = np.mean(scores <= self.q_hat)
|
| 115 |
+
|
| 116 |
+
summary = {
|
| 117 |
+
"alpha": alpha,
|
| 118 |
+
"target_coverage": 1 - alpha,
|
| 119 |
+
"n_calibration": n,
|
| 120 |
+
"q_hat": self.q_hat,
|
| 121 |
+
"q_residual": self.q_residual,
|
| 122 |
+
"empirical_coverage_cal": float(empirical_coverage),
|
| 123 |
+
"mean_score": float(scores.mean()),
|
| 124 |
+
"median_score": float(np.median(scores)),
|
| 125 |
+
"cal_pos_rate": float(cal_labels.mean()),
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
print(f" Conformal calibration (alpha={alpha}):")
|
| 129 |
+
print(f" Calibration set: {n} examples ({cal_labels.sum():.0f} positive)")
|
| 130 |
+
print(f" q_hat (nonconformity): {self.q_hat:.4f}")
|
| 131 |
+
print(f" q_residual: {self.q_residual:.4f}")
|
| 132 |
+
print(f" Empirical coverage (cal): {empirical_coverage:.4f}")
|
| 133 |
+
|
| 134 |
+
return summary
|
| 135 |
+
|
| 136 |
+
def predict(self, test_probs: np.ndarray) -> list[ConformalResult]:
|
| 137 |
+
"""Produce conformal prediction sets for test examples.
|
| 138 |
+
|
| 139 |
+
For each test example, returns:
|
| 140 |
+
- Prediction set: set of risk tiers that could contain the true risk
|
| 141 |
+
- Probability bounds: [lower, upper] interval on the true probability
|
| 142 |
+
|
| 143 |
+
Coverage guarantee: P(true_tier ∈ prediction_set) ≥ 1 - alpha
|
| 144 |
+
"""
|
| 145 |
+
if not self.is_calibrated:
|
| 146 |
+
raise RuntimeError("Must call calibrate() before predict()")
|
| 147 |
+
|
| 148 |
+
results = []
|
| 149 |
+
for p in test_probs:
|
| 150 |
+
# Probability bounds from residual quantile
|
| 151 |
+
lower = max(0.0, p - self.q_residual)
|
| 152 |
+
upper = min(1.0, p + self.q_residual)
|
| 153 |
+
|
| 154 |
+
# Prediction set: all tiers that overlap with [lower, upper]
|
| 155 |
+
pred_set = []
|
| 156 |
+
for tier_name, (tier_lo, tier_hi) in self.TIERS.items():
|
| 157 |
+
if lower < tier_hi and upper > tier_lo:
|
| 158 |
+
pred_set.append(tier_name)
|
| 159 |
+
|
| 160 |
+
results.append(ConformalResult(
|
| 161 |
+
prediction_set=pred_set,
|
| 162 |
+
set_size=len(pred_set),
|
| 163 |
+
risk_prob=float(p),
|
| 164 |
+
lower_bound=lower,
|
| 165 |
+
upper_bound=upper,
|
| 166 |
+
))
|
| 167 |
+
|
| 168 |
+
return results
|
| 169 |
+
|
| 170 |
+
def evaluate(
|
| 171 |
+
self,
|
| 172 |
+
test_probs: np.ndarray,
|
| 173 |
+
test_labels: np.ndarray,
|
| 174 |
+
) -> dict:
|
| 175 |
+
"""Evaluate conformal prediction on test set.
|
| 176 |
+
|
| 177 |
+
Reports:
|
| 178 |
+
- Marginal coverage: fraction of test examples where true label
|
| 179 |
+
falls within prediction set
|
| 180 |
+
- Average set size: how informative are the predictions
|
| 181 |
+
- Coverage by tier: per-tier coverage (conditional coverage)
|
| 182 |
+
- Efficiency: 1 - (avg_set_size / n_tiers)
|
| 183 |
+
"""
|
| 184 |
+
if not self.is_calibrated:
|
| 185 |
+
raise RuntimeError("Must call calibrate() before evaluate()")
|
| 186 |
+
|
| 187 |
+
results = self.predict(test_probs)
|
| 188 |
+
|
| 189 |
+
# Map labels to tiers for coverage check
|
| 190 |
+
def label_to_tier(prob: float) -> str:
|
| 191 |
+
for tier_name, (lo, hi) in self.TIERS.items():
|
| 192 |
+
if lo <= prob < hi:
|
| 193 |
+
return tier_name
|
| 194 |
+
return "CRITICAL" # prob == 1.0
|
| 195 |
+
|
| 196 |
+
# True "tier" based on actual probability (binary: 0 or 1)
|
| 197 |
+
true_tiers = [label_to_tier(float(l)) for l in test_labels]
|
| 198 |
+
|
| 199 |
+
# Marginal coverage: does the prediction set contain the true tier?
|
| 200 |
+
covered = [
|
| 201 |
+
true_tier in result.prediction_set
|
| 202 |
+
for true_tier, result in zip(true_tiers, results)
|
| 203 |
+
]
|
| 204 |
+
marginal_coverage = np.mean(covered)
|
| 205 |
+
|
| 206 |
+
# Average set size
|
| 207 |
+
set_sizes = [r.set_size for r in results]
|
| 208 |
+
avg_set_size = np.mean(set_sizes)
|
| 209 |
+
|
| 210 |
+
# Coverage by true label value
|
| 211 |
+
pos_mask = test_labels == 1
|
| 212 |
+
neg_mask = test_labels == 0
|
| 213 |
+
pos_coverage = np.mean([c for c, m in zip(covered, pos_mask) if m]) if pos_mask.sum() > 0 else 0.0
|
| 214 |
+
neg_coverage = np.mean([c for c, m in zip(covered, neg_mask) if m]) if neg_mask.sum() > 0 else 0.0
|
| 215 |
+
|
| 216 |
+
# Set size distribution
|
| 217 |
+
size_counts = {}
|
| 218 |
+
for s in set_sizes:
|
| 219 |
+
size_counts[s] = size_counts.get(s, 0) + 1
|
| 220 |
+
|
| 221 |
+
# Efficiency: lower set sizes = more informative
|
| 222 |
+
efficiency = 1.0 - (avg_set_size / len(self.TIERS))
|
| 223 |
+
|
| 224 |
+
# Interval width statistics
|
| 225 |
+
widths = [r.upper_bound - r.lower_bound for r in results]
|
| 226 |
+
|
| 227 |
+
metrics = {
|
| 228 |
+
"alpha": self.alpha,
|
| 229 |
+
"target_coverage": 1 - self.alpha,
|
| 230 |
+
"marginal_coverage": float(marginal_coverage),
|
| 231 |
+
"coverage_guarantee_met": bool(marginal_coverage >= (1 - self.alpha - 0.01)),
|
| 232 |
+
"avg_set_size": float(avg_set_size),
|
| 233 |
+
"efficiency": float(efficiency),
|
| 234 |
+
"positive_coverage": float(pos_coverage),
|
| 235 |
+
"negative_coverage": float(neg_coverage),
|
| 236 |
+
"set_size_distribution": {str(k): v for k, v in sorted(size_counts.items())},
|
| 237 |
+
"n_test": len(test_labels),
|
| 238 |
+
"mean_interval_width": float(np.mean(widths)),
|
| 239 |
+
"median_interval_width": float(np.median(widths)),
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
print(f"\n Conformal Prediction Evaluation (alpha={self.alpha}):")
|
| 243 |
+
print(f" Target coverage: {1 - self.alpha:.1%}")
|
| 244 |
+
print(f" Marginal coverage: {marginal_coverage:.1%} "
|
| 245 |
+
f"{'OK' if metrics['coverage_guarantee_met'] else 'VIOLATION'}")
|
| 246 |
+
print(f" Positive coverage: {pos_coverage:.1%}")
|
| 247 |
+
print(f" Negative coverage: {neg_coverage:.1%}")
|
| 248 |
+
print(f" Avg set size: {avg_set_size:.2f} / {len(self.TIERS)} tiers")
|
| 249 |
+
print(f" Efficiency: {efficiency:.1%}")
|
| 250 |
+
print(f" Mean interval: [{np.mean([r.lower_bound for r in results]):.3f}, "
|
| 251 |
+
f"{np.mean([r.upper_bound for r in results]):.3f}]")
|
| 252 |
+
print(f" Set size dist: {size_counts}")
|
| 253 |
+
|
| 254 |
+
return metrics
|
| 255 |
+
|
| 256 |
+
def save_state(self) -> dict:
|
| 257 |
+
"""Serialize calibration state for checkpoint saving."""
|
| 258 |
+
if not self.is_calibrated:
|
| 259 |
+
return {"is_calibrated": False}
|
| 260 |
+
return {
|
| 261 |
+
"is_calibrated": True,
|
| 262 |
+
"alpha": self.alpha,
|
| 263 |
+
"q_hat": self.q_hat,
|
| 264 |
+
"q_residual": self.q_residual,
|
| 265 |
+
"n_cal": self.n_cal,
|
| 266 |
+
"tiers": {k: list(v) for k, v in self.TIERS.items()},
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
@classmethod
|
| 270 |
+
def from_state(cls, state: dict) -> "ConformalPredictor":
|
| 271 |
+
"""Restore from serialized state."""
|
| 272 |
+
obj = cls()
|
| 273 |
+
if state.get("is_calibrated", False):
|
| 274 |
+
obj.alpha = state["alpha"]
|
| 275 |
+
obj.q_hat = state["q_hat"]
|
| 276 |
+
obj.q_residual = state["q_residual"]
|
| 277 |
+
obj.n_cal = state["n_cal"]
|
| 278 |
+
obj.is_calibrated = True
|
| 279 |
+
return obj
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def run_conformal_at_multiple_levels(
    cal_probs: np.ndarray,
    cal_labels: np.ndarray,
    test_probs: np.ndarray,
    test_labels: np.ndarray,
    alphas: list[float] | None = None,
) -> dict:
    """Run conformal prediction at multiple coverage levels.

    Useful for reporting: "at 90% coverage, avg set size = X;
    at 95%, avg set size = Y; at 99%, avg set size = Z".

    Args:
        cal_probs: Model probabilities on the calibration set, shape (n,).
        cal_labels: Binary labels on the calibration set, shape (n,).
        test_probs: Model probabilities on the test set.
        test_labels: Binary labels on the test set.
        alphas: Miscoverage rates to sweep; defaults to [0.01, 0.05, 0.10, 0.20].

    Returns:
        {"alpha_<a>": {"conformal_metrics": ..., "conformal_state": ...}} per level.
    """
    if alphas is None:
        alphas = [0.01, 0.05, 0.10, 0.20]

    all_results = {}
    for alpha in alphas:
        # Fresh predictor per level: calibration quantiles depend on alpha.
        cp = ConformalPredictor()
        cp.calibrate(cal_probs, cal_labels, alpha=alpha)
        eval_metrics = cp.evaluate(test_probs, test_labels)
        all_results[f"alpha_{alpha}"] = {
            "conformal_metrics": eval_metrics,
            "conformal_state": cp.save_state(),
        }

    return all_results
|
src/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Evaluation metrics for conjunction prediction models."""
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.metrics import (
|
| 6 |
+
average_precision_score,
|
| 7 |
+
roc_auc_score,
|
| 8 |
+
f1_score,
|
| 9 |
+
precision_recall_curve,
|
| 10 |
+
mean_absolute_error,
|
| 11 |
+
mean_squared_error,
|
| 12 |
+
classification_report,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def find_optimal_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> tuple[float, float]:
    """Pick the decision threshold with the best F1 on the PR curve.

    Returns:
        (threshold, f1) at the F1-maximizing point of the curve.
    """
    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    # precision_recall_curve yields one more (precision, recall) point than
    # thresholds; drop the final point so the arrays align with `thr`.
    p, r = prec[:-1], rec[:-1]
    f1_per_thr = 2 * (p * r) / (p + r + 1e-8)  # epsilon avoids 0/0
    best = int(np.argmax(f1_per_thr))
    return float(thr[best]), float(f1_per_thr[best])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def evaluate_risk(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """
    Evaluate risk classification predictions.

    Args:
        y_true: binary ground truth labels
        y_prob: predicted probabilities
        threshold: classification threshold (used for f1_at_50)

    Returns: dict of metrics including optimal threshold F1
    """
    y_pred_fixed = (y_prob >= threshold).astype(int)

    # Threshold-free metrics are degenerate without both classes present,
    # so report 0.0 in those cases rather than raising.
    results = {
        "auc_pr": float(average_precision_score(y_true, y_prob)) if y_true.sum() > 0 else 0.0,
        "auc_roc": float(roc_auc_score(y_true, y_prob)) if len(np.unique(y_true)) > 1 else 0.0,
        "f1_at_50": float(f1_score(y_true, y_pred_fixed, zero_division=0)),
        "n_positive": int(y_true.sum()),
        "n_total": int(len(y_true)),
        "pos_rate": float(y_true.mean()),
    }

    # Find optimal threshold that maximizes F1.
    # Both "optimal_threshold" and "threshold" carry the same value —
    # presumably one key is legacy; verify against downstream consumers.
    if y_true.sum() > 0:
        opt_threshold, opt_f1 = find_optimal_threshold(y_true, y_prob)
        results["f1"] = opt_f1
        results["optimal_threshold"] = opt_threshold
        results["threshold"] = opt_threshold
    else:
        # No positives: threshold search is meaningless, fall back to the
        # fixed-threshold F1 and the caller-supplied threshold.
        results["f1"] = results["f1_at_50"]
        results["optimal_threshold"] = threshold
        results["threshold"] = threshold

    # Recall at fixed precision levels (best recall achievable while the
    # operating point keeps precision >= target); keys only emitted when
    # there is at least one positive example.
    if y_true.sum() > 0:
        precisions, recalls, thresholds = precision_recall_curve(y_true, y_prob)
        for target_precision in [0.3, 0.5, 0.7]:
            mask = precisions >= target_precision
            if mask.any():
                best_recall = recalls[mask].max()
                results[f"recall_at_prec_{int(target_precision*100)}"] = float(best_recall)
            else:
                results[f"recall_at_prec_{int(target_precision*100)}"] = 0.0

    return results
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def evaluate_miss_distance(y_true_log: np.ndarray, y_pred_log: np.ndarray) -> dict:
    """
    Evaluate miss distance regression (log-scale).

    Args:
        y_true_log: log1p(miss_distance_km) ground truth
        y_pred_log: log1p(miss_distance_km) predictions

    Returns: dict of metrics
    """
    log_err = np.asarray(y_true_log) - np.asarray(y_pred_log)
    mae_log = float(np.mean(np.abs(log_err)))
    rmse_log = float(np.sqrt(np.mean(log_err ** 2)))

    # Back-transform to km so the error magnitudes are interpretable.
    true_km = np.expm1(y_true_log)
    pred_km = np.expm1(y_pred_log)
    abs_err_km = np.abs(true_km - pred_km)

    return {
        "mae_log": mae_log,
        "rmse_log": rmse_log,
        "mae_km": float(np.mean(abs_err_km)),
        "median_abs_error_km": float(np.median(abs_err_km)),
    }
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def full_evaluation(
    model_name: str,
    y_risk_true: np.ndarray,
    y_risk_prob: np.ndarray,
    y_miss_true_log: np.ndarray,
    y_miss_pred_log: np.ndarray,
) -> dict:
    """Run full evaluation suite for a model.

    Combines risk-classification and miss-distance metrics into one flat
    dict (keyed by metric name, plus "model") and prints a summary table.

    Args:
        model_name: Display name used in the printed report and result dict.
        y_risk_true: Binary risk ground-truth labels.
        y_risk_prob: Predicted risk probabilities.
        y_miss_true_log: log1p(miss_distance_km) ground truth.
        y_miss_pred_log: log1p(miss_distance_km) predictions.

    Returns:
        {"model": model_name, **risk_metrics, **miss_metrics}
    """
    risk_metrics = evaluate_risk(y_risk_true, y_risk_prob)
    miss_metrics = evaluate_miss_distance(y_miss_true_log, y_miss_pred_log)

    # Flat merge: metric dicts share no keys with each other or "model".
    results = {"model": model_name, **risk_metrics, **miss_metrics}

    print(f"\n{'='*60}")
    print(f" {model_name}")
    print(f"{'='*60}")
    print(f" Risk Classification:")
    print(f" AUC-PR: {risk_metrics['auc_pr']:.4f}")
    print(f" AUC-ROC: {risk_metrics['auc_roc']:.4f}")
    print(f" F1 (opt): {risk_metrics['f1']:.4f} (threshold={risk_metrics.get('optimal_threshold', 0.5):.3f})")
    print(f" F1 (0.50): {risk_metrics['f1_at_50']:.4f}")
    print(f" Positives: {risk_metrics['n_positive']}/{risk_metrics['n_total']} "
          f"({risk_metrics['pos_rate']:.1%})")
    print(f" Miss Distance:")
    print(f" MAE (log): {miss_metrics['mae_log']:.4f}")
    print(f" MAE (km): {miss_metrics['mae_km']:.2f}")
    print(f" Median AE: {miss_metrics['median_abs_error_km']:.2f} km")
    print(f"{'='*60}")

    return results
|
src/evaluation/staleness.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
| 2 |
+
"""TLE Staleness Sensitivity Experiment.
|
| 3 |
+
|
| 4 |
+
Evaluates how model performance degrades as CDM data becomes stale.
|
| 5 |
+
Simulates staleness by filtering CDM sequences to only include updates
|
| 6 |
+
received at least `cutoff_days` before TCA.
|
| 7 |
+
|
| 8 |
+
The Kelvins test set has time_to_tca in [2.0, 7.0] days, so meaningful
|
| 9 |
+
cutoffs are in that range. A cutoff of 2.0 keeps all data (baseline),
|
| 10 |
+
while a cutoff of 6.0 keeps only the earliest CDMs.
|
| 11 |
+
|
| 12 |
+
Ground-truth labels always come from the ORIGINAL (untruncated) test set —
|
| 13 |
+
we're measuring how well models predict with less-recent information.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import torch
|
| 19 |
+
from torch.utils.data import DataLoader
|
| 20 |
+
|
| 21 |
+
from src.data.cdm_loader import build_events, events_to_flat_features, get_feature_columns
|
| 22 |
+
from src.data.sequence_builder import CDMSequenceDataset
|
| 23 |
+
from src.evaluation.metrics import evaluate_risk
|
| 24 |
+
|
| 25 |
+
# Staleness cutoffs (days before TCA)
|
| 26 |
+
# 2.0 = keep all data (baseline), 6.0 = only very early CDMs
|
| 27 |
+
DEFAULT_CUTOFFS = [2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0]
|
| 28 |
+
QUICK_CUTOFFS = [2.0, 4.0, 6.0]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def truncate_cdm_dataframe(df: pd.DataFrame, cutoff_days: float) -> pd.DataFrame:
    """Drop every CDM row received closer to TCA than ``cutoff_days``.

    Emulates data staleness: with cutoff=4.0 the model is restricted to
    CDMs issued at least four days ahead of the closest approach.
    """
    keep = df["time_to_tca"] >= cutoff_days
    return df.loc[keep].copy()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_ground_truth_labels(df: pd.DataFrame) -> dict:
    """Extract per-event ground truth labels from the FULL (untruncated) dataset.

    The label source for each event is its last CDM, i.e. the row with the
    smallest time_to_tca (closest to TCA).

    Returns: {event_id: {"risk_label": int, "miss_log": float, "altitude_km": float}}
    """
    out: dict = {}
    for eid, grp in df.groupby("event_id"):
        # Row closest to TCA == the final CDM for this event.
        final = grp.loc[grp["time_to_tca"].idxmin()]
        # Kelvins "risk" is log10(Pc); > -5 marks a high-risk event.
        is_risky = 1 if final["risk"] > -5 else 0
        miss_val = max(final.get("miss_distance", 0.0), 0.0)
        out[int(eid)] = {
            "risk_label": is_risky,
            "miss_log": float(np.log1p(miss_val)),
            "altitude_km": float(final.get("t_h_apo", 0.0)),
        }
    return out
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def evaluate_baseline_at_cutoff(baseline_model, ground_truth: dict, cutoff: float) -> dict:
    """Score the altitude-only baseline; its inputs never go stale."""
    records = list(ground_truth.values())
    alts = np.array([r["altitude_km"] for r in records])
    y_true = np.array([r["risk_label"] for r in records])
    # Baseline ignores CDM history entirely, so no truncation is needed.
    probs, _ = baseline_model.predict(alts)
    out = evaluate_risk(y_true, probs)
    out["cutoff"] = cutoff
    out["n_events"] = len(y_true)
    return out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def evaluate_xgboost_at_cutoff(
    xgboost_model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    feature_cols: list[str],
    cutoff: float,
) -> dict:
    """Evaluate XGBoost on truncated CDM data.

    Args:
        xgboost_model: trained XGBoostConjunctionModel (uses .scaler and .predict_risk)
        truncated_df: CDM rows surviving the staleness cutoff
        ground_truth: {event_id: labels} built from the FULL test set
        feature_cols: feature column names used to build per-event features
        cutoff: staleness cutoff in days (recorded in the returned metrics)

    Returns:
        dict of risk metrics plus "cutoff" and "n_events"; all-zero metrics
        when no events (or no positive events) survive the cutoff.
    """
    events = build_events(truncated_df, feature_cols)
    if len(events) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    X, _, _ = events_to_flat_features(events)

    # Keep only events that exist in the (untruncated) ground truth.
    event_ids = [e.event_id for e in events]
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    X = X[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    # AUC-PR is undefined with zero positives; report zeros instead of crashing.
    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    # Align the feature count with what the scaler was fitted on. The model
    # may have been trained on augmented data with extra columns (pad with
    # zeros) or on fewer columns (truncate). The original code performed this
    # alignment twice (a pad-only pass before row masking and a pad+truncate
    # pass after); one pass here is sufficient — masking drops rows, not columns.
    expected = xgboost_model.scaler.n_features_in_
    if X.shape[1] < expected:
        X = np.pad(X, ((0, 0), (0, expected - X.shape[1])), constant_values=0)
    elif X.shape[1] > expected:
        X = X[:, :expected]

    risk_probs = xgboost_model.predict_risk(X)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def evaluate_pitft_at_cutoff(
    model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    train_ds: CDMSequenceDataset,
    device: torch.device,
    temperature: float = 1.0,
    cutoff: float = 0.0,
    batch_size: int = 128,
) -> dict:
    """Evaluate PI-TFT on truncated CDM data with temperature scaling."""
    # The truncated frame may be missing columns the dataset expects;
    # backfill any absent ones with zeros before building sequences.
    frame = truncated_df.copy()
    for name in train_ds.temporal_cols + train_ds.static_cols:
        if name not in frame.columns:
            frame[name] = 0.0

    eval_ds = CDMSequenceDataset(
        frame,
        temporal_cols=train_ds.temporal_cols,
        static_cols=train_ds.static_cols,
    )
    # Reuse training normalization statistics so inputs match the model.
    eval_ds.set_normalization(train_ds)

    if len(eval_ds) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    ids = [ev["event_id"] for ev in eval_ds.events]

    loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model.eval()
    prob_chunks = []
    with torch.no_grad():
        for batch in loader:
            risk_logit, _, _, _ = model(
                batch["temporal"].to(device),
                batch["static"].to(device),
                batch["time_to_tca"].to(device),
                batch["mask"].to(device),
            )
            # Temperature-scaled sigmoid for calibrated probabilities.
            prob_chunks.append(
                torch.sigmoid(risk_logit / temperature).cpu().numpy().flatten()
            )

    probs = np.concatenate(prob_chunks)

    # Keep only events that have a ground-truth label.
    keep = np.array([eid in ground_truth for eid in ids])
    probs = probs[keep]
    kept_ids = [eid for eid in ids if eid in ground_truth]
    labels = np.array([ground_truth[eid]["risk_label"] for eid in kept_ids])

    if len(labels) == 0 or labels.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(labels), "cutoff": cutoff}

    metrics = evaluate_risk(labels, probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = int(len(labels))
    return metrics
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def run_staleness_experiment(
    baseline_model,
    xgboost_model,
    pitft_model,
    pitft_checkpoint: dict,
    test_df: pd.DataFrame,
    train_ds: CDMSequenceDataset,
    feature_cols: list[str],
    device: torch.device,
    cutoffs: list[float] = None,
    quick: bool = False,
) -> dict:
    """Run the full staleness experiment across all cutoffs and models.

    Args:
        baseline_model: OrbitalShellBaseline instance
        xgboost_model: XGBoostConjunctionModel instance
        pitft_model: PhysicsInformedTFT (eval mode), or None to skip
        pitft_checkpoint: checkpoint dict with temperature
        test_df: ORIGINAL (untruncated) test DataFrame
        train_ds: CDMSequenceDataset from training data (for normalization)
        feature_cols: list of feature column names for XGBoost
        device: torch device
        cutoffs: list of staleness cutoffs (days before TCA)
        quick: if True, use fewer cutoffs

    Returns:
        dict with the cutoff list, test-set counts, and one metrics dict per
        cutoff under the "baseline", "xgboost", and "pitft" keys.
    """
    if cutoffs is None:
        cutoffs = QUICK_CUTOFFS if quick else DEFAULT_CUTOFFS

    # Labels always come from the FULL test set; only model inputs go stale.
    ground_truth = get_ground_truth_labels(test_df)
    n_pos = sum(1 for gt in ground_truth.values() if gt["risk_label"] == 1)
    print(f"\nGround truth: {len(ground_truth)} events, {n_pos} positive")

    # Temperature from the PI-TFT checkpoint (1.0 = no calibration scaling).
    temperature = 1.0
    if pitft_checkpoint:
        temperature = pitft_checkpoint.get("temperature", 1.0)

    results = {
        "cutoffs": cutoffs,
        "n_test_events": len(ground_truth),
        "n_positive": n_pos,
        "baseline": [],
        "xgboost": [],
        "pitft": [],
    }

    for cutoff in cutoffs:
        print(f"\n{'='*50}")
        print(f"Staleness cutoff: {cutoff:.1f} days")
        print(f"{'='*50}")

        # Remove every CDM closer to TCA than the cutoff.
        truncated = truncate_cdm_dataframe(test_df, cutoff)
        n_events = truncated["event_id"].nunique()
        n_rows = len(truncated)
        print(f" Surviving: {n_events} events, {n_rows} CDMs")

        # Baseline (uses altitude only — constant across cutoffs)
        bl = evaluate_baseline_at_cutoff(baseline_model, ground_truth, cutoff)
        results["baseline"].append(bl)
        print(f" Baseline AUC-PR={bl.get('auc_pr', 0):.4f}, F1={bl.get('f1', 0):.4f}")

        # XGBoost: only meaningful when at least one event survives.
        if n_events > 0:
            xgb = evaluate_xgboost_at_cutoff(
                xgboost_model, truncated, ground_truth, feature_cols, cutoff
            )
        else:
            xgb = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["xgboost"].append(xgb)
        print(f" XGBoost AUC-PR={xgb.get('auc_pr', 0):.4f}, "
              f"F1={xgb.get('f1', 0):.4f} ({xgb.get('n_events', 0)} events)")

        # PI-TFT: skipped when no events survive or no model was supplied.
        if n_events > 0 and pitft_model is not None:
            tft = evaluate_pitft_at_cutoff(
                pitft_model, truncated, ground_truth, train_ds,
                device, temperature=temperature, cutoff=cutoff,
            )
        else:
            tft = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["pitft"].append(tft)
        print(f" PI-TFT AUC-PR={tft.get('auc_pr', 0):.4f}, "
              f"F1={tft.get('f1', 0):.4f}")

    return results
|
src/model/__init__.py
ADDED
|
File without changes
|
src/model/baseline.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Model 1: Naive Baseline -- Orbital Shell Density Prior.
|
| 3 |
+
|
| 4 |
+
Predicts collision risk based solely on the altitude band of the conjunction,
|
| 5 |
+
using historical base rates. This establishes that altitude alone is predictive
|
| 6 |
+
(LEO is more crowded) but insufficient for actionable conjunction assessment.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class OrbitalShellBaseline:
    """
    Altitude-band collision rate baseline.

    Buckets conjunction events into fixed-width altitude bands (default
    50 km) and predicts the historical mean risk and miss distance for
    whichever band a new event falls into.
    """

    def __init__(self, bin_width_km: float = 50.0):
        # Width of each altitude bucket in km.
        self.bin_width = bin_width_km
        # Per-band statistics, keyed by band center altitude.
        self.bins: dict[int, dict] = {}
        # Fallback statistics for altitudes with no fitted band.
        self.global_stats: dict = {}

    def _altitude_to_bin(self, alt_km: float) -> int:
        # Round to the nearest multiple of the band width.
        return int(round(alt_km / self.bin_width) * self.bin_width)

    def fit(self, altitudes: np.ndarray, y_risk: np.ndarray, y_miss_log: np.ndarray):
        """
        Fit baseline from altitude array and labels.

        Args:
            altitudes: altitude in km for each event
            y_risk: binary risk labels
            y_miss_log: log1p(miss_distance_km) targets
        """
        # Whole-population fallback for unseen altitude bands.
        self.global_stats = {
            "mean_risk": float(np.mean(y_risk)),
            "mean_miss_log": float(np.mean(y_miss_log)),
            "count": int(len(y_risk)),
        }

        # Bucket sample indices by altitude band, then summarize each band.
        members = defaultdict(list)
        for idx, alt in enumerate(altitudes):
            members[self._altitude_to_bin(alt)].append(idx)

        self.bins = {}
        for band, idxs in members.items():
            risks = np.asarray([y_risk[i] for i in idxs])
            misses = np.asarray([y_miss_log[i] for i in idxs])
            self.bins[band] = {
                "mean_risk": float(np.mean(risks)),
                "mean_miss_log": float(np.mean(misses)),
                "count": len(idxs),
                "risk_rate": float(np.sum(risks) / len(risks)),
            }

        print(f"Baseline fit: {len(self.bins)} altitude bins, "
              f"global risk rate = {self.global_stats['mean_risk']:.4f}")

    def predict(self, altitudes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict risk probability and log miss distance for each altitude.

        Returns: (risk_probs, miss_log_preds)
        """
        fallback_risk = self.global_stats["mean_risk"]
        fallback_miss = self.global_stats["mean_miss_log"]

        risk_out = []
        miss_out = []
        for alt in altitudes:
            stats = self.bins.get(self._altitude_to_bin(alt))
            if stats is not None:
                risk_out.append(stats["risk_rate"])
                miss_out.append(stats["mean_miss_log"])
            else:
                # Unseen band — fall back to global averages.
                risk_out.append(fallback_risk)
                miss_out.append(fallback_miss)

        return np.array(risk_out), np.array(miss_out)

    def save(self, path: Path):
        """Save model to JSON."""
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "bin_width": self.bin_width,
            # JSON keys must be strings; band keys are re-int'd on load.
            "bins": {str(band): stats for band, stats in self.bins.items()},
            "global_stats": self.global_stats,
        }
        with open(path, "w") as f:
            json.dump(payload, f, indent=2)
        print(f"Baseline saved to {path}")

    @classmethod
    def load(cls, path: Path) -> "OrbitalShellBaseline":
        """Load model from JSON."""
        with open(path) as f:
            payload = json.load(f)
        restored = cls(bin_width_km=payload["bin_width"])
        restored.bins = {int(band): stats for band, stats in payload["bins"].items()}
        restored.global_stats = payload["global_stats"]
        return restored
|
src/model/classical.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Model 2: Classical ML -- XGBoost on engineered CDM features.
|
| 3 |
+
|
| 4 |
+
Dual-head model:
|
| 5 |
+
- Risk classifier (binary: high-risk vs safe)
|
| 6 |
+
- Miss distance regressor (log-scale km)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pickle
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from xgboost import XGBClassifier, XGBRegressor
|
| 13 |
+
from sklearn.preprocessing import StandardScaler
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class XGBoostConjunctionModel:
    """XGBoost with engineered CDM features."""

    def __init__(self):
        self.scaler = StandardScaler()

        # Classification head: flags high-risk conjunctions.
        self.risk_classifier = XGBClassifier(
            n_estimators=500,
            max_depth=8,
            learning_rate=0.05,
            scale_pos_weight=50,  # severe class imbalance
            eval_metric="aucpr",
            tree_method="hist",
            random_state=42,
        )

        # Regression head: log-scale miss distance in km.
        self.miss_regressor = XGBRegressor(
            n_estimators=500,
            max_depth=8,
            learning_rate=0.05,
            objective="reg:squaredlogerror",
            tree_method="hist",
            random_state=42,
        )

    def fit(
        self,
        X_train: np.ndarray,
        y_risk: np.ndarray,
        y_miss_log: np.ndarray,
        X_val: np.ndarray = None,
        y_risk_val: np.ndarray = None,
        y_miss_val: np.ndarray = None,
    ):
        """Train both heads."""
        X_scaled = self.scaler.fit_transform(X_train)

        print(f"Training risk classifier (pos_rate={y_risk.mean():.4f}) ...")
        risk_eval = None
        if X_val is not None:
            risk_eval = [(self.scaler.transform(X_val), y_risk_val)]
        self.risk_classifier.fit(
            X_scaled, y_risk,
            eval_set=risk_eval,
            verbose=50,
        )

        # squaredlogerror requires strictly positive targets, so clamp.
        print("Training miss distance regressor ...")
        miss_eval = None
        if X_val is not None:
            miss_eval = [(self.scaler.transform(X_val), np.clip(y_miss_val, 1e-6, None))]
        self.miss_regressor.fit(
            X_scaled, np.clip(y_miss_log, 1e-6, None),
            eval_set=miss_eval,
            verbose=50,
        )

    def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict risk probability and miss distance.

        Returns: (risk_probs, miss_distance_km)
        """
        scaled = self.scaler.transform(X)
        probs = self.risk_classifier.predict_proba(scaled)[:, 1]
        # Invert the log1p transform used for the regression target.
        km = np.expm1(self.miss_regressor.predict(scaled))
        return probs, km

    def predict_risk(self, X: np.ndarray) -> np.ndarray:
        """Predict risk probability only."""
        return self.risk_classifier.predict_proba(self.scaler.transform(X))[:, 1]

    def save(self, path: Path):
        """Save all components."""
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "scaler": self.scaler,
            "risk_classifier": self.risk_classifier,
            "miss_regressor": self.miss_regressor,
        }
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"XGBoost model saved to {path}")

    @classmethod
    def load(cls, path: Path) -> "XGBoostConjunctionModel":
        """Load all components."""
        with open(path, "rb") as f:
            payload = pickle.load(f)
        restored = cls()
        restored.scaler = payload["scaler"]
        restored.risk_classifier = payload["risk_classifier"]
        restored.miss_regressor = payload["miss_regressor"]
        return restored
|
src/model/deep.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-08
|
| 2 |
+
"""Model 3: Physics-Informed Temporal Fusion Transformer (PI-TFT).
|
| 3 |
+
|
| 4 |
+
Architecture overview (think of it like reading serial lab values):
|
| 5 |
+
|
| 6 |
+
1. VARIABLE SELECTION: Not all 22 CDM features matter equally. The model
|
| 7 |
+
learns attention weights over features -- e.g., miss_distance and
|
| 8 |
+
covariance shrinkage rate might matter more than raw orbital elements.
|
| 9 |
+
This is like a doctor learning which labs to focus on.
|
| 10 |
+
|
| 11 |
+
2. STATIC CONTEXT: Object properties (altitude, size, eccentricity) don't
|
| 12 |
+
change between CDM updates. They're encoded once and injected as context
|
| 13 |
+
into the temporal processing. Like knowing the patient's age and history.
|
| 14 |
+
|
| 15 |
+
3. CONTINUOUS TIME EMBEDDING: CDMs arrive at irregular intervals (not evenly
|
| 16 |
+
spaced). Instead of positional encoding (position 1, 2, 3...), we embed
|
| 17 |
+
the actual time_to_tca value. The model knows "this CDM was 3.2 days
|
| 18 |
+
before closest approach" vs "this one was 0.5 days before."
|
| 19 |
+
|
| 20 |
+
4. TEMPORAL SELF-ATTENTION: The Transformer reads the full CDM sequence and
|
| 21 |
+
learns which updates were most informative. A sudden miss distance drop
|
| 22 |
+
at day -2 gets more attention than a stable reading at day -5.
|
| 23 |
+
|
| 24 |
+
5. PREDICTION HEADS: The final hidden state (from the most recent CDM)
|
| 25 |
+
feeds into two prediction heads:
|
| 26 |
+
- Risk classifier: sigmoid probability of high-risk collision
|
| 27 |
+
- Miss distance regressor: predicted log(miss distance in km)
|
| 28 |
+
|
| 29 |
+
6. PHYSICS LOSS: The training loss includes a penalty when the model predicts
|
| 30 |
+
a miss distance BELOW the Minimum Orbital Intersection Distance (MOID).
|
| 31 |
+
MOID is the closest the two orbits can geometrically get. Predicting
|
| 32 |
+
closer than MOID is physically impossible (without a maneuver), so we
|
| 33 |
+
penalize it. This is like penalizing a model for predicting negative
|
| 34 |
+
blood pressure -- constraining outputs to the physically possible range.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
import math
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GatedResidualNetwork(nn.Module):
|
| 44 |
+
"""
|
| 45 |
+
Gated skip connection with ELU activation and layer norm.
|
| 46 |
+
|
| 47 |
+
Think of this as a "smart residual block" -- it learns how much of the
|
| 48 |
+
transformed input to mix with the original. The gate (sigmoid) controls
|
| 49 |
+
this: gate=0 means pass through unchanged, gate=1 means fully transformed.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, d_model: int, d_hidden: int = None, dropout: float = 0.1):
|
| 53 |
+
super().__init__()
|
| 54 |
+
d_hidden = d_hidden or d_model
|
| 55 |
+
self.fc1 = nn.Linear(d_model, d_hidden)
|
| 56 |
+
self.fc2 = nn.Linear(d_hidden, d_model)
|
| 57 |
+
self.gate_fc = nn.Linear(d_hidden, d_model)
|
| 58 |
+
self.norm = nn.LayerNorm(d_model)
|
| 59 |
+
self.dropout = nn.Dropout(dropout)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
residual = x
|
| 63 |
+
h = F.elu(self.fc1(x))
|
| 64 |
+
h = self.dropout(h)
|
| 65 |
+
transform = self.fc2(h)
|
| 66 |
+
gate = torch.sigmoid(self.gate_fc(h))
|
| 67 |
+
return self.norm(residual + gate * transform)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class VariableSelectionNetwork(nn.Module):
    """
    Learns which input features matter most via softmax attention.

    Every scalar feature is projected to d_model independently; a gating
    network over the concatenated projections produces N weights summing
    to 1, and the weighted sum is refined by a GRN. The weights are
    interpretable: they reveal which CDM columns drove the prediction.
    """

    def __init__(self, n_features: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model

        # One scalar -> d_model projection per feature.
        # NOTE: submodule names/order preserved for state_dict compatibility.
        self.feature_projections = nn.ModuleList([
            nn.Linear(1, d_model) for _ in range(n_features)
        ])

        # Flattened projections -> softmax feature weights.
        self.gate_network = nn.Sequential(
            nn.Linear(n_features * d_model, n_features),
            nn.Softmax(dim=-1),
        )

        self.grn = GatedResidualNetwork(d_model, dropout=dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (..., n_features) — can be (B, F) for static or (B, S, F) for temporal

        Returns:
            output: (..., d_model) — weighted combination of projected features
            weights: (..., n_features) — attention weights (sum to 1)
        """
        # Independent projection of each scalar feature slice.
        per_feature = [
            layer(x[..., idx:idx + 1])
            for idx, layer in enumerate(self.feature_projections)
        ]

        # (..., n_features, d_model)
        stacked = torch.stack(per_feature, dim=-2)

        # Gate over the flattened projections: (..., n_features)
        weights = self.gate_network(stacked.reshape(*stacked.shape[:-2], -1))

        # Attention-weighted sum across the feature axis, then GRN refinement.
        combined = torch.sum(stacked * weights.unsqueeze(-1), dim=-2)
        return self.grn(combined), weights
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class PhysicsInformedTFT(nn.Module):
|
| 127 |
+
"""
|
| 128 |
+
Physics-Informed Temporal Fusion Transformer for conjunction assessment.
|
| 129 |
+
|
| 130 |
+
Input flow:
|
| 131 |
+
temporal_features (B, S, F_t) → Variable Selection → time embedding → self-attention → attention pool → heads
|
| 132 |
+
static_features (B, F_s) → Variable Selection → context injection ↗
|
| 133 |
+
|
| 134 |
+
Output:
|
| 135 |
+
risk_logit: (B, 1) — raw logit for risk classification (apply sigmoid for probability)
|
| 136 |
+
miss_log: (B, 1) — predicted log1p(miss_distance_km)
|
| 137 |
+
pc_log10: (B, 1) — predicted log10(Pc) collision probability (when has_pc_head=True)
|
| 138 |
+
feature_weights: (B, S, F_t) — which temporal features mattered
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(
|
| 142 |
+
self,
|
| 143 |
+
n_temporal_features: int,
|
| 144 |
+
n_static_features: int,
|
| 145 |
+
d_model: int = 128,
|
| 146 |
+
n_heads: int = 4,
|
| 147 |
+
n_layers: int = 2,
|
| 148 |
+
dropout: float = 0.15,
|
| 149 |
+
max_seq_len: int = 30,
|
| 150 |
+
):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.d_model = d_model
|
| 153 |
+
self.max_seq_len = max_seq_len
|
| 154 |
+
|
| 155 |
+
# --- Variable Selection Networks ---
|
| 156 |
+
self.temporal_vsn = VariableSelectionNetwork(n_temporal_features, d_model, dropout)
|
| 157 |
+
self.static_vsn = VariableSelectionNetwork(n_static_features, d_model, dropout)
|
| 158 |
+
|
| 159 |
+
# --- Static context encoding ---
|
| 160 |
+
self.static_encoder = nn.Sequential(
|
| 161 |
+
nn.Linear(d_model, d_model),
|
| 162 |
+
nn.GELU(),
|
| 163 |
+
nn.Dropout(dropout),
|
| 164 |
+
)
|
| 165 |
+
# Static -> enrichment vector that's added to each temporal step
|
| 166 |
+
self.static_to_enrichment = nn.Linear(d_model, d_model)
|
| 167 |
+
|
| 168 |
+
# --- Continuous time embedding ---
|
| 169 |
+
# Instead of fixed positional encoding, we embed the actual time_to_tca
|
| 170 |
+
self.time_embedding = nn.Sequential(
|
| 171 |
+
nn.Linear(1, d_model // 2),
|
| 172 |
+
nn.GELU(),
|
| 173 |
+
nn.Linear(d_model // 2, d_model),
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# --- Transformer encoder layers ---
|
| 177 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 178 |
+
d_model=d_model,
|
| 179 |
+
nhead=n_heads,
|
| 180 |
+
dim_feedforward=d_model * 2,
|
| 181 |
+
dropout=dropout,
|
| 182 |
+
activation="gelu",
|
| 183 |
+
batch_first=True,
|
| 184 |
+
norm_first=True,
|
| 185 |
+
)
|
| 186 |
+
self.transformer_encoder = nn.TransformerEncoder(
|
| 187 |
+
encoder_layer, num_layers=n_layers
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# --- Pre/post attention processing ---
|
| 191 |
+
self.pre_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)
|
| 192 |
+
self.post_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)
|
| 193 |
+
|
| 194 |
+
# --- Attention-weighted pooling ---
|
| 195 |
+
# Learns which time steps matter most instead of just taking the last one.
|
| 196 |
+
# Softmax attention over all real positions, with padding masked out.
|
| 197 |
+
self.pool_attention = nn.Sequential(
|
| 198 |
+
nn.Linear(d_model, d_model // 2),
|
| 199 |
+
nn.Tanh(),
|
| 200 |
+
nn.Linear(d_model // 2, 1),
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# --- Prediction heads ---
|
| 204 |
+
self.risk_head = nn.Sequential(
|
| 205 |
+
nn.LayerNorm(d_model),
|
| 206 |
+
nn.Linear(d_model, 64),
|
| 207 |
+
nn.GELU(),
|
| 208 |
+
nn.Dropout(dropout),
|
| 209 |
+
nn.Linear(64, 1),
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
self.miss_head = nn.Sequential(
|
| 213 |
+
nn.LayerNorm(d_model),
|
| 214 |
+
nn.Linear(d_model, 64),
|
| 215 |
+
nn.GELU(),
|
| 216 |
+
nn.Dropout(dropout),
|
| 217 |
+
nn.Linear(64, 1),
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# --- Collision probability head ---
|
| 221 |
+
# Predicts log10(Pc) directly instead of binary risk classification.
|
| 222 |
+
# Pc ranges from ~1e-20 to ~1e-1, so log10 scale maps to [-20, -1].
|
| 223 |
+
# The Kelvins `risk` column is already log10(Pc).
|
| 224 |
+
self.pc_head = nn.Sequential(
|
| 225 |
+
nn.LayerNorm(d_model),
|
| 226 |
+
nn.Linear(d_model, 64),
|
| 227 |
+
nn.GELU(),
|
| 228 |
+
nn.Dropout(dropout),
|
| 229 |
+
nn.Linear(64, 1),
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
def encode_sequence(
|
| 233 |
+
self,
|
| 234 |
+
temporal_features: torch.Tensor, # (B, S, F_t)
|
| 235 |
+
static_features: torch.Tensor, # (B, F_s)
|
| 236 |
+
time_to_tca: torch.Tensor, # (B, S, 1)
|
| 237 |
+
mask: torch.Tensor, # (B, S) — True for real, False for padding
|
| 238 |
+
):
|
| 239 |
+
"""Encode CDM sequence into per-timestep hidden states.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
hidden: (B, S, D) per-timestep representations after Transformer
|
| 243 |
+
temporal_weights: (B, S, F_t) variable selection attention weights
|
| 244 |
+
"""
|
| 245 |
+
# 1. Variable selection -- learn which features matter
|
| 246 |
+
temporal_selected, temporal_weights = self.temporal_vsn(temporal_features)
|
| 247 |
+
# temporal_selected: (B, S, D), temporal_weights: (B, S, F_t)
|
| 248 |
+
|
| 249 |
+
static_selected, static_weights = self.static_vsn(static_features)
|
| 250 |
+
# static_selected: (B, D)
|
| 251 |
+
|
| 252 |
+
# 2. Static context -- compute enrichment vector
|
| 253 |
+
static_ctx = self.static_encoder(static_selected) # (B, D)
|
| 254 |
+
enrichment = self.static_to_enrichment(static_ctx) # (B, D)
|
| 255 |
+
|
| 256 |
+
# 3. Continuous time embedding
|
| 257 |
+
t_embed = self.time_embedding(time_to_tca) # (B, S, D)
|
| 258 |
+
|
| 259 |
+
# 4. Combine: temporal + time + static context
|
| 260 |
+
x = temporal_selected + t_embed + enrichment.unsqueeze(1)
|
| 261 |
+
|
| 262 |
+
# 5. Pre-attention GRN
|
| 263 |
+
x = self.pre_attn_grn(x)
|
| 264 |
+
|
| 265 |
+
# 6. Transformer self-attention
|
| 266 |
+
# Convert mask: True=real -> need to invert for PyTorch's src_key_padding_mask
|
| 267 |
+
# PyTorch expects True=ignore, so we flip
|
| 268 |
+
padding_mask = ~mask # (B, S), True = pad position to ignore
|
| 269 |
+
x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
|
| 270 |
+
|
| 271 |
+
# 7. Post-attention GRN
|
| 272 |
+
x = self.post_attn_grn(x)
|
| 273 |
+
|
| 274 |
+
return x, temporal_weights
|
| 275 |
+
|
| 276 |
+
def forward(
|
| 277 |
+
self,
|
| 278 |
+
temporal_features: torch.Tensor, # (B, S, F_t)
|
| 279 |
+
static_features: torch.Tensor, # (B, F_s)
|
| 280 |
+
time_to_tca: torch.Tensor, # (B, S, 1)
|
| 281 |
+
mask: torch.Tensor, # (B, S) — True for real, False for padding
|
| 282 |
+
):
|
| 283 |
+
B, S, _ = temporal_features.shape
|
| 284 |
+
|
| 285 |
+
# Steps 1-7: encode sequence into per-timestep hidden states
|
| 286 |
+
x, temporal_weights = self.encode_sequence(
|
| 287 |
+
temporal_features, static_features, time_to_tca, mask
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# 8. Attention-weighted pooling over all real positions
|
| 291 |
+
# Instead of just the last CDM, learn which time steps matter most
|
| 292 |
+
attn_scores = self.pool_attention(x).squeeze(-1) # (B, S)
|
| 293 |
+
# Mask padding positions with -inf so they get zero attention
|
| 294 |
+
attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
|
| 295 |
+
attn_weights = F.softmax(attn_scores, dim=-1) # (B, S)
|
| 296 |
+
# Handle all-padding edge case (shouldn't happen but be safe)
|
| 297 |
+
attn_weights = attn_weights.nan_to_num(0.0)
|
| 298 |
+
x_pooled = (x * attn_weights.unsqueeze(-1)).sum(dim=1) # (B, D)
|
| 299 |
+
|
| 300 |
+
# 9. Prediction heads
|
| 301 |
+
risk_logit = self.risk_head(x_pooled) # (B, 1)
|
| 302 |
+
miss_log = self.miss_head(x_pooled) # (B, 1)
|
| 303 |
+
pc_log10 = self.pc_head(x_pooled) # (B, 1) — log10(Pc)
|
| 304 |
+
|
| 305 |
+
return risk_logit, miss_log, pc_log10, temporal_weights
|
| 306 |
+
|
| 307 |
+
def count_parameters(self) -> int:
|
| 308 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class SigmoidFocalLoss(nn.Module):
    """
    Focal Loss for binary classification (Lin et al., 2017).

    Down-weights well-classified examples so the model focuses on hard cases.
    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    With gamma=0, this reduces to standard alpha-weighted BCE.
    With gamma=2, easy examples (p_t > 0.9) get ~100x less weight.

    Args:
        alpha: weight for the positive class; negatives are weighted (1 - alpha).
        gamma: focusing exponent; larger values discount easy examples more.
        reduction: "mean", "sum", or "none".
    """

    def __init__(self, alpha: float = 0.75, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        # BUG FIX: previously any unrecognized reduction silently behaved as
        # "mean" (including "sum"); validate up front instead.
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(
                f"reduction must be 'mean', 'sum', or 'none', got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute focal loss on raw logits against {0, 1} float targets."""
        p = torch.sigmoid(logits)
        # p_t = probability assigned to the true class
        p_t = targets * p + (1 - targets) * (1 - p)
        # alpha_t = alpha for the positive class, (1 - alpha) for the negative
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        # focal modulator: (1 - p_t)^gamma down-weights easy examples
        focal_weight = (1 - p_t) ** self.gamma
        # per-element BCE (numerically stable, works on logits directly)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        loss = alpha_t * focal_weight * bce
        if self.reduction == "none":
            return loss
        if self.reduction == "sum":
            # BUG FIX: "sum" previously fell through to mean.
            return loss.sum()
        return loss.mean()
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class PhysicsInformedLoss(nn.Module):
    """
    Combined task loss + physics regularization.

    Total loss = risk_weight * risk_loss + miss_weight * MSE(miss_distance)
                 + pc_weight * MSE(log10_Pc) + physics_weight * ReLU(MOID - predicted_miss)

    The physics term: MOID (Minimum Orbital Intersection Distance) is the
    geometric minimum distance between two orbits. The actual miss distance
    at closest approach CANNOT be less than MOID (without a maneuver).
    If the model predicts miss < MOID, we penalize it.

    The Pc term: direct regression on log10(collision probability). The Kelvins
    `risk` column is log10(Pc), giving us 162K labeled examples. This lets
    the model output calibrated collision probabilities, not just binary risk.

    Args:
        risk_weight: weight on the risk classification term.
        miss_weight: weight on the miss-distance regression term.
        pc_weight: weight on the log10(Pc) regression term.
        physics_weight: weight on the MOID violation penalty.
        pos_weight: positive-class weight for BCE (ignored when use_focal).
        use_focal: use SigmoidFocalLoss instead of weighted BCE for risk.
        focal_alpha / focal_gamma: focal-loss hyperparameters.
    """

    def __init__(
        self,
        risk_weight: float = 1.0,
        miss_weight: float = 0.1,
        pc_weight: float = 0.3,
        physics_weight: float = 0.2,
        pos_weight: float = 50.0,
        use_focal: bool = False,
        focal_alpha: float = 0.75,
        focal_gamma: float = 2.0,
    ):
        super().__init__()
        self.risk_weight = risk_weight
        self.miss_weight = miss_weight
        self.pc_weight = pc_weight
        self.physics_weight = physics_weight
        # Branch on a flag rather than isinstance(..., SigmoidFocalLoss) in
        # forward(), decoupling this class from the focal-loss implementation.
        self.use_focal = use_focal
        if use_focal:
            # reduction="none" so per-sample domain weights can be applied in
            # forward(). BUG FIX: previously domain_weight was silently
            # dropped when use_focal=True.
            self.risk_loss = SigmoidFocalLoss(
                alpha=focal_alpha, gamma=focal_gamma, reduction="none"
            )
        else:
            self.risk_loss = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(pos_weight)
            )
        # Kept for attribute compatibility; forward() computes MSE inline so
        # per-sample domain weights can be applied.
        self.miss_loss = nn.MSELoss()

    @staticmethod
    def _weighted_mean(values: torch.Tensor, weight: torch.Tensor = None) -> torch.Tensor:
        """Mean of per-sample values, optionally scaled by per-sample weights."""
        if weight is not None:
            return (values * weight).mean()
        return values.mean()

    def forward(
        self,
        risk_logit: torch.Tensor,           # (B, 1)
        miss_pred_log: torch.Tensor,        # (B, 1)
        risk_target: torch.Tensor,          # (B,)
        miss_target_log: torch.Tensor,      # (B,)
        pc_pred_log10: torch.Tensor = None,   # (B, 1) predicted log10(Pc)
        pc_target_log10: torch.Tensor = None, # (B,) target log10(Pc)
        moid_log: torch.Tensor = None,        # (B,) optional, log1p(MOID_km)
        domain_weight: torch.Tensor = None,   # (B,) per-sample weight
    ) -> tuple[torch.Tensor, dict]:
        """Compute the combined loss and a dict of per-term scalars."""
        logits = risk_logit.squeeze(-1)

        # Risk classification — always computed per-sample so domain weights
        # apply uniformly to both the focal and the weighted-BCE path.
        if self.use_focal:
            risk_per_sample = self.risk_loss(logits, risk_target)  # (B,)
        else:
            risk_per_sample = F.binary_cross_entropy_with_logits(
                logits, risk_target,
                pos_weight=self.risk_loss.pos_weight.to(risk_logit.device),
                reduction="none",
            )
        L_risk = self._weighted_mean(risk_per_sample, domain_weight)

        # Miss distance regression (MSE in log space), domain-weighted.
        miss_residual = (miss_pred_log.squeeze(-1) - miss_target_log) ** 2
        L_miss = self._weighted_mean(miss_residual, domain_weight)

        # Collision probability regression (MSE on log10(Pc)), when labeled.
        L_pc = torch.tensor(0.0, device=risk_logit.device)
        if pc_pred_log10 is not None and pc_target_log10 is not None:
            pc_residual = (pc_pred_log10.squeeze(-1) - pc_target_log10) ** 2
            L_pc = self._weighted_mean(pc_residual, domain_weight)

        # Physics constraint: predicted miss distance must be >= MOID.
        L_physics = torch.tensor(0.0, device=risk_logit.device)
        if moid_log is not None:
            # Violation = how far below MOID the prediction sits.
            violation = F.relu(moid_log - miss_pred_log.squeeze(-1))
            L_physics = violation.mean()

        total = (self.risk_weight * L_risk
                 + self.miss_weight * L_miss
                 + self.pc_weight * L_pc
                 + self.physics_weight * L_physics)

        metrics = {
            "loss": total.item(),
            "risk_loss": L_risk.item(),
            "miss_loss": L_miss.item(),
            "pc_loss": L_pc.item(),
            "physics_loss": L_physics.item(),
        }

        return total, metrics
|
src/model/pretrain.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-10
|
| 2 |
+
"""Self-supervised pre-training for the PI-TFT encoder.
|
| 3 |
+
|
| 4 |
+
Masked Feature Reconstruction: mask 60% of CDM temporal features at random
|
| 5 |
+
per timestep, train the Transformer encoder to reconstruct them. This forces
|
| 6 |
+
the model to learn feature correlations, temporal dynamics, and
|
| 7 |
+
static-temporal interactions from ALL CDM data (no labels needed).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from src.model.deep import PhysicsInformedTFT
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CDMMaskingStrategy(nn.Module):
    """Randomly mask temporal features per timestep for reconstruction pre-training.

    For each real timestep (respecting the padding mask), a random subset of
    the temporal features is replaced with a learnable [MASK] token.
    """

    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # One learnable [MASK] value per temporal feature.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(
        self,
        temporal: torch.Tensor,      # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply random feature masking.

        Returns:
            masked_temporal: (B, S, F_t) with masked positions set to mask_token
            feature_mask: (B, S, F_t) bool — True where features were masked
        """
        batch, seq_len, n_feat = temporal.shape

        # Bernoulli(mask_ratio) draw per (batch, step, feature); True = mask.
        to_mask = (
            torch.rand(batch, seq_len, n_feat, device=temporal.device)
            < self.mask_ratio
        )
        # Padding timesteps are never masked.
        to_mask = to_mask & padding_mask.unsqueeze(-1)

        # Substitute the learnable token at every masked position.
        corrupted = temporal.clone()
        token_grid = self.mask_token.expand(batch, seq_len, -1)
        corrupted[to_mask] = token_grid[to_mask]

        return corrupted, to_mask
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MaskedReconstructionHead(nn.Module):
    """Small two-layer MLP decoder that maps encoder states back to feature space.

    Deliberately lightweight so representational capacity is forced into the
    encoder rather than the decoder.
    """

    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        # LayerNorm -> Linear -> GELU -> Dropout -> projection to F_t.
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Map per-timestep encoder output to reconstructed features.

        Args:
            hidden: (B, S, D) encoder hidden states

        Returns:
            (B, S, F_t) reconstructed temporal features
        """
        return self.decoder(hidden)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class PretrainingWrapper(nn.Module):
    """Bundles the PI-TFT encoder with masking and reconstruction for pre-training.

    Forward pass: draw a feature mask, corrupt the input with the [MASK]
    token, run encode_sequence() on the corrupted sequence, then decode the
    hidden states back to feature space.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )

    def forward(
        self,
        temporal: torch.Tensor,     # (B, S, F_t)
        static: torch.Tensor,       # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1)
        mask: torch.Tensor,         # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
            feature_mask: (B, S, F_t) bool — True where features were masked
            original: (B, S, F_t) unmasked inputs (for the reconstruction loss)
        """
        # Keep a pristine copy before corruption for the loss target.
        clean = temporal.clone()

        # Corrupt, encode, decode.
        corrupted, feature_mask = self.masking(temporal, mask)
        hidden, _ = self.encoder.encode_sequence(
            corrupted, static, time_to_tca, mask
        )
        reconstructed = self.reconstruction_head(hidden)

        return reconstructed, feature_mask, clean
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class PretrainingLoss(nn.Module):
    """MSE reconstruction loss restricted to the masked feature positions."""

    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,       # (B, S, F_t)
        feature_mask: torch.Tensor,   # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        """Return (loss, metrics) where loss is MSE over masked positions only."""
        squared_error = (reconstructed - original).pow(2)
        masked_error = squared_error[feature_mask]

        if masked_error.numel() > 0:
            loss = masked_error.mean()
        else:
            # Nothing was masked: return a differentiable zero so backward()
            # still works on the batch.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)

        return loss, {"reconstruction_loss": loss.item()}
|
src/model/triage.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by Claude Code -- 2026-02-13
|
| 2 |
+
"""Urgency tier classifier for conjunction events."""
|
| 3 |
+
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class UrgencyTier(str, Enum):
    """Conjunction urgency levels, ordered from least to most severe."""

    LOW = "LOW"
    MODERATE = "MODERATE"
    HIGH = "HIGH"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class TriageResult:
    """Outcome of urgency triage for a single conjunction event."""

    tier: UrgencyTier          # assigned urgency level
    color: str                 # hex display color associated with the tier
    recommendation: str        # human-readable recommended action
    risk_probability: float    # predicted risk probability that drove the tier
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def classify_urgency(risk_prob: float) -> TriageResult:
    """Classify conjunction urgency based on predicted risk probability.

    Tiers:
        LOW (risk <= 0.10): Monitor only
        MODERATE (0.10 < risk <= 0.40): Assess maneuver options
        HIGH (risk > 0.40): Immediate action required
    """
    # (upper bound, tier, color, recommendation), checked in ascending order.
    bounded_tiers = (
        (0.10, UrgencyTier.LOW, "#4fff8a",
         "Monitor conjunction. No action required."),
        (0.40, UrgencyTier.MODERATE, "#ffb84f",
         "Assess maneuver options. Increased monitoring recommended."),
    )
    for upper_bound, tier, color, recommendation in bounded_tiers:
        if risk_prob <= upper_bound:
            return TriageResult(
                tier=tier,
                color=color,
                recommendation=recommendation,
                risk_probability=risk_prob,
            )
    # Above every bound: highest urgency.
    return TriageResult(
        tier=UrgencyTier.HIGH,
        color="#ff4f5a",
        recommendation="Immediate action required. Initiate collision avoidance maneuver.",
        risk_probability=risk_prob,
    )
|
src/utils/__init__.py
ADDED
|
File without changes
|