Spaces:
Sleeping
Sleeping
| # Generated by Claude Code -- 2026-02-13 | |
| """FastAPI backend for Panacea collision avoidance inference.""" | |
| import json | |
| import os | |
| import numpy as np | |
| import torch | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import sys | |
# Project root (two levels above this file). It is placed first on sys.path so
# that the ``src.*`` package imports that follow can be resolved.
ROOT = Path(__file__).parent.parent
sys.path[0:0] = [str(ROOT)]
| from src.model.baseline import OrbitalShellBaseline | |
| from src.model.classical import XGBoostConjunctionModel | |
| from src.model.deep import PhysicsInformedTFT | |
| from src.model.triage import classify_urgency | |
| from src.data.sequence_builder import TEMPORAL_FEATURES, STATIC_FEATURES, MAX_SEQ_LEN | |
# Hugging Face Hub repository that hosts the trained model/result artifacts.
HF_REPO_ID = "DTanzillo/panacea-models"

# Process-wide model registry. load_models() fills it and the lifespan handler
# clears it on shutdown. Keys used in this file: "baseline", "xgboost", "pitft",
# plus the PI-TFT helpers "pitft_checkpoint" and "pitft_device".
models: dict = {}
def _copy_missing_files(src_dir: Path, dst_dir: Path) -> list[str]:
    """Copy regular files from ``src_dir`` into ``dst_dir`` unless already present.

    Existing destination files are never overwritten and subdirectories are
    skipped. A missing ``src_dir`` is not an error and simply copies nothing.

    Returns:
        The names of the files that were copied, in sorted order.
    """
    import shutil

    if not src_dir.is_dir():
        return []
    copied = []
    for entry in sorted(src_dir.iterdir()):
        dst = dst_dir / entry.name
        if entry.is_file() and not dst.exists():
            shutil.copy2(entry, dst)
            print(f" Downloaded {entry.name} from HF Hub")
            copied.append(entry.name)
    return copied


def download_models_from_hf(model_dir: Path, results_dir: Path):
    """Best-effort fetch of model and result artifacts from the Hugging Face Hub.

    Downloads the ``models/*`` and ``results/*`` files of ``HF_REPO_ID`` (using
    the ``HF_TOKEN`` environment variable if set). Any file not already present
    locally is then copied into ``model_dir`` / ``results_dir``. A missing
    ``models/`` or ``results/`` folder in the snapshot no longer prevents the
    other folder from being copied.

    Any failure is printed and swallowed so startup can continue with whatever
    models exist locally. Examples: ``huggingface_hub`` not installed, network
    or auth errors, I/O errors.
    """
    try:
        from huggingface_hub import snapshot_download

        local = Path(
            snapshot_download(
                HF_REPO_ID,
                token=os.environ.get("HF_TOKEN"),
                allow_patterns=["models/*", "results/*"],
            )
        )
        _copy_missing_files(local / "models", model_dir)
        _copy_missing_files(local / "results", results_dir)
    except Exception as e:  # deliberate best-effort: never block startup
        print(f" HF Hub download skipped: {e}")
def load_models():
    """Load all available models into the module-level ``models`` registry.

    Artifacts are read from ``ROOT/models``; that directory and ``ROOT/results``
    are created if missing. When ``baseline.json`` is absent, a best-effort
    download from the Hugging Face Hub is attempted first. Each model is loaded
    only if its file exists, so any subset may end up loaded:

    - ``baseline.json``  -> ``models["baseline"]``
    - ``xgboost.pkl``    -> ``models["xgboost"]``
    - ``transformer.pt`` -> ``models["pitft"]``, plus ``models["pitft_checkpoint"]``
      and ``models["pitft_device"]`` for inference.
    """
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Try downloading from HF Hub if local models are missing
    if not (model_dir / "baseline.json").exists():
        print(" Local models not found, trying HuggingFace Hub...")
        download_models_from_hf(model_dir, results_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    baseline_path = model_dir / "baseline.json"
    if baseline_path.exists():
        models["baseline"] = OrbitalShellBaseline.load(baseline_path)
        print(" Loaded baseline model")

    # NOTE(review): a ``.pkl`` artifact is presumably unpickled by the loader;
    # only load files from a trusted source.
    xgboost_path = model_dir / "xgboost.pkl"
    if xgboost_path.exists():
        models["xgboost"] = XGBoostConjunctionModel.load(xgboost_path)
        print(" Loaded XGBoost model")

    pitft_path = model_dir / "transformer.pt"
    if pitft_path.exists():
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
        # This file may come from the HF Hub download above; only use trusted repos.
        checkpoint = torch.load(pitft_path, map_location=device, weights_only=False)
        config = checkpoint["config"]
        model = PhysicsInformedTFT(
            n_temporal_features=config["n_temporal"],
            n_static_features=config["n_static"],
            d_model=config.get("d_model", 128),
            n_heads=config.get("n_heads", 4),
            n_layers=config.get("n_layers", 2),
        ).to(device)
        # strict=False for backward compat: old checkpoints lack pc_head weights.
        # Report what was skipped, so a genuinely mismatched checkpoint is not
        # silently served with layers left at their constructor initialization.
        incompatible = model.load_state_dict(checkpoint["model_state"], strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f" PI-TFT state_dict mismatch: missing={list(incompatible.missing_keys)}, "
                f"unexpected={list(incompatible.unexpected_keys)}"
            )
        model.eval()
        models["pitft"] = model
        models["pitft_checkpoint"] = checkpoint
        models["pitft_device"] = device

        temp = checkpoint.get("temperature", 1.0)
        has_pc = checkpoint.get("has_pc_head", False)
        # "epoch" is only used in this log line; don't fail startup if it is absent.
        epoch = checkpoint.get("epoch", "?")
        print(f" Loaded PI-TFT (epoch {epoch}, T={temp:.3f}, pc_head={'yes' if has_pc else 'no'})")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models on startup, release them on shutdown.

    The function is wrapped with ``asynccontextmanager`` because FastAPI's
    ``lifespan=`` expects an async context manager factory. Passing a bare async
    generator function only works through Starlette's deprecated fallback.
    """
    print("Loading models ...")
    load_models()
    # Hide the PI-TFT helper entries ("pitft_checkpoint", "pitft_device").
    loaded = [k for k in models if not k.startswith("pitft_")]
    print(f"Models loaded: {loaded}")
    try:
        yield
    finally:
        # Clear the registry even if shutdown is reached via an exception.
        models.clear()
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS docs. Confirm that no cookies/auth are needed
# cross-origin, or pin explicit origins.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI(
    title="Panacea — Satellite Collision Avoidance API",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# --- Pydantic models ---
class CDMFeatures(BaseModel):
    """A sequence of CDM feature snapshots for one conjunction event."""
    # Optional caller-supplied identifier; not read anywhere in this file.
    event_id: Optional[int] = None
    # One dict per CDM. The inference code treats the LAST element as the most
    # recent CDM, so presumably the list is oldest-first; confirm with callers.
    cdm_sequence: list[dict]


class BulkScreenRequest(BaseModel):
    """TLE data for pairwise screening."""
    # One dict per object. Keys read by bulk_screen: OBJECT_NAME, NORAD_CAT_ID,
    # MEAN_MOTION, ECCENTRICITY, RA_OF_ASC_NODE (all optional, with defaults there).
    tles: list[dict]
    # Maximum number of candidate pairs returned, highest risk score first.
    top_k: int = 10
# --- Endpoints ---
# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; confirm it is registered with ``app``.
async def health():
    """Report service status, the loaded models, and the PI-TFT device.

    ``models_loaded`` lists the loaded models in the fixed order baseline,
    xgboost, pitft. ``device`` falls back to ``"cpu"`` when PI-TFT is not loaded.
    """
    available = [name for name in ("baseline", "xgboost", "pitft") if name in models]
    return {
        "status": "healthy",
        "models_loaded": available,
        "device": str(models.get("pitft_device", "cpu")),
        "n_models": len(available),
    }
def _triage_payload(triage) -> dict:
    """Serialize a ``classify_urgency`` result into the response's ``triage`` dict."""
    return {
        "tier": triage.tier.value,
        "color": triage.color,
        "recommendation": triage.recommendation,
    }


# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; confirm it is registered with ``app``.
async def predict_conjunction(features: CDMFeatures):
    """Run inference on a single conjunction event across all loaded models.

    Each loaded model contributes one entry (``baseline``, ``xgboost``, ``pitft``).
    Each entry holds a risk probability, a miss-distance estimate in km and a
    triage verdict. PI-TFT also reports the collision probability, both linear
    and log10. An empty ``cdm_sequence`` yields ``{"error": "Empty CDM sequence"}``.
    """
    cdm_seq = features.cdm_sequence
    if not cdm_seq:
        return {"error": "Empty CDM sequence"}

    results = {}
    latest = cdm_seq[-1]
    # The baseline is keyed on one altitude from the latest CDM: "t_h_apo", else
    # "c_h_apo", else 500.0 (presumably target/chaser apogee height in km).
    altitude = latest.get("t_h_apo", latest.get("c_h_apo", 500.0))

    if "baseline" in models:
        risk_probs, miss_preds = models["baseline"].predict(np.array([altitude]))
        risk = float(risk_probs[0])
        triage = classify_urgency(risk)
        results["baseline"] = {
            "risk_probability": risk,
            # expm1 maps the model's miss output back to km (presumably log1p-scaled).
            "miss_distance_km": float(np.expm1(miss_preds[0])),
            "triage": _triage_payload(triage),
        }

    if "xgboost" in models:
        risk_probs, miss_km = models["xgboost"].predict(_build_xgboost_features(cdm_seq))
        risk = float(risk_probs[0])
        triage = classify_urgency(risk)
        results["xgboost"] = {
            "risk_probability": risk,
            "miss_distance_km": float(miss_km[0]),
            "triage": _triage_payload(triage),
        }

    if "pitft" in models:
        risk, miss_log, pc_log10 = _run_pitft_inference(cdm_seq)
        triage = classify_urgency(risk)
        results["pitft"] = {
            "risk_probability": risk,
            "miss_distance_km": float(np.expm1(miss_log)),
            "collision_probability": float(10 ** pc_log10),
            "collision_probability_log10": pc_log10,
            "triage": _triage_payload(triage),
        }

    return results
# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; confirm it is registered with ``app``.
async def model_comparison():
    """Return pre-computed model comparison rows, plus a PI-TFT row if available.

    Reads ``results/model_comparison.json``. It must hold a JSON array, because
    the PI-TFT row is appended to it. If ``results/deep_model_results.json``
    exists, one more row is appended, built from its ``model`` name and ``test``
    metrics. Missing files are skipped, so the result may be an empty list.
    """
    results_dir = ROOT / "results"
    results = []

    comparison_path = results_dir / "model_comparison.json"
    if comparison_path.exists():
        # JSON is UTF-8 by spec; don't rely on the platform's locale encoding.
        with open(comparison_path, encoding="utf-8") as f:
            results = json.load(f)

    deep_path = results_dir / "deep_model_results.json"
    if deep_path.exists():
        with open(deep_path, encoding="utf-8") as f:
            deep = json.load(f)
        results.append({"model": deep["model"], **deep["test"]})

    return results
# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; confirm it is registered with ``app``.
async def experiment_results():
    """Return staleness experiment results from ``results/staleness_experiment.json``.

    If the file is missing, returns an ``{"error": ...}`` dict that names the
    script which produces it.
    """
    exp_path = ROOT / "results" / "staleness_experiment.json"
    if not exp_path.exists():
        return {"error": "No experiment results found. Run: python scripts/run_experiment.py"}
    # JSON is UTF-8 by spec; don't rely on the platform's locale encoding.
    with open(exp_path, encoding="utf-8") as f:
        return json.load(f)
def _orbit_shells(tles: list[dict]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Derive apogee/perigee altitudes (km) and RAAN (deg) from TLE/OMM-style dicts.

    The semi-major axis comes from the mean motion (rev/day) via
    ``a = (mu / n^2)^(1/3)``. Altitudes subtract a spherical Earth radius of
    6371 km. Missing MEAN_MOTION / ECCENTRICITY / RA_OF_ASC_NODE default to
    15.0 / 0.0 / 0.0.
    """
    mu_km3_s2 = 398600.4418
    r_earth_km = 6371.0
    mean_motions = np.array([t.get("MEAN_MOTION", 15.0) for t in tles], dtype=float)
    # rev/day -> rad/s; the clip keeps a zero mean motion from dividing by zero.
    n_rad = np.clip(mean_motions * 2 * np.pi / 86400.0, 1e-10, None)
    sma = (mu_km3_s2 / (n_rad ** 2)) ** (1.0 / 3.0)
    ecc = np.array([t.get("ECCENTRICITY", 0.0) for t in tles], dtype=float)
    raan = np.array([t.get("RA_OF_ASC_NODE", 0.0) for t in tles], dtype=float)
    apogee = sma * (1 + ecc) - r_earth_km
    perigee = sma * (1 - ecc) - r_earth_km
    return apogee, perigee, raan


def _candidate_pairs(apogee: np.ndarray, perigee: np.ndarray, raan: np.ndarray,
                     max_raan_diff_deg: float = 30.0) -> tuple[np.ndarray, np.ndarray]:
    """Return index arrays ``(i, j)``, with ``i < j``, of plausibly-conjuncting pairs.

    A pair qualifies when its [perigee, apogee] altitude bands overlap and its
    RAAN separation, measured the short way round the circle, is below
    ``max_raan_diff_deg``. Uses n x n boolean matrices, so memory is O(n^2).
    """
    alt_overlap = ((apogee[:, None] >= perigee[None, :]) &
                   (apogee[None, :] >= perigee[:, None]))
    # Wrap into [0, 360) so out-of-range RAAN inputs still give a valid distance.
    raan_diff = np.abs(raan[:, None] - raan[None, :]) % 360.0
    raan_diff = np.minimum(raan_diff, 360.0 - raan_diff)
    # Upper triangle (k=1) keeps each unordered pair once and drops self-pairs.
    candidates = np.triu(alt_overlap & (raan_diff < max_raan_diff_deg), k=1)
    return np.nonzero(candidates)


# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; confirm it is registered with ``app``.
async def bulk_screen(request: BulkScreenRequest):
    """Screen TLE pairs for potential conjunctions using orbital filtering.

    Candidate pairs have overlapping altitude bands and similar RAAN. They are
    scored with the baseline model at the mean of the two apogee altitudes; if
    the baseline is not loaded, every pair scores 0.5. At most ``top_k`` pairs
    are returned, highest score first, with ties kept in pair-index order.
    ``n_candidates`` is the number of candidates before the ``top_k`` cut.
    """
    tles = request.tles
    n = len(tles)
    # A negative top_k would make the slice below drop pairs from the end
    # instead of limiting the count; treat it as "return nothing".
    top_k = max(request.top_k, 0)
    if n < 2:
        return {"pairs": [], "n_candidates": 0, "n_total": n}

    names = [t.get("OBJECT_NAME", f"Object {i}") for i, t in enumerate(tles)]
    norad_ids = [t.get("NORAD_CAT_ID", 0) for t in tles]

    apogee, perigee, raan = _orbit_shells(tles)
    pairs_i, pairs_j = _candidate_pairs(apogee, perigee, raan)
    n_candidates = len(pairs_i)
    if n_candidates == 0:
        return {"pairs": [], "n_candidates": 0, "n_total": n}

    pair_altitudes = (apogee[pairs_i] + apogee[pairs_j]) / 2.0
    if "baseline" in models:
        risk_scores, miss_estimates = models["baseline"].predict(pair_altitudes)
    else:
        risk_scores = np.full(n_candidates, 0.5)
        miss_estimates = np.zeros(n_candidates)

    # Stable sort: tied scores (e.g. the 0.5 fallback) keep a deterministic order.
    top_indices = np.argsort(-risk_scores, kind="stable")[:top_k]
    result_pairs = []
    for idx in top_indices:
        i, j = int(pairs_i[idx]), int(pairs_j[idx])
        miss = miss_estimates[idx]
        result_pairs.append({
            "name_1": names[i],
            "name_2": names[j],
            "norad_1": norad_ids[i],
            "norad_2": norad_ids[j],
            "risk_score": float(risk_scores[idx]),
            "altitude_km": float(pair_altitudes[idx]),
            "miss_estimate_km": float(np.expm1(miss)) if miss > 0 else 0.0,
        })
    return {
        "pairs": result_pairs,
        "n_candidates": int(n_candidates),
        "n_total": n,
    }
# --- Helper functions ---
def _build_xgboost_features(cdm_sequence: list[dict]) -> np.ndarray:
    """Flatten a CDM sequence (list of dicts) into a single XGBoost input row.

    Replicates events_to_flat_features() logic for a single event. The row
    starts with every int/float field of the most recent CDM, sorted by key and
    excluding event_id / time_to_tca / risk / mission_id. Eight sequence-level
    features follow:
    - the CDM count;
    - miss-distance mean, std and trend vs. time_to_tca;
    - the risk trend;
    - the first-minus-last miss-distance change;
    - the last CDM's time_to_tca and relative_speed.

    Returns a float32 array of shape (1, F) with non-finite values zeroed. When
    an XGBoost model is loaded, the row is zero-padded or truncated to the
    scaler's ``n_features_in_``.
    """
    latest = cdm_sequence[-1]
    skip = {"event_id", "time_to_tca", "risk", "mission_id"}
    # NOTE(review): these columns come from the request's own keys, so alignment
    # with the trained feature order relies on clients sending the same fields.
    numeric_keys = sorted(
        key for key, value in latest.items()
        if isinstance(value, (int, float)) and key not in skip
    )
    base = np.array([float(latest[key]) for key in numeric_keys], dtype=np.float32)

    def column(name: str, default: float) -> np.ndarray:
        return np.array([float(cdm.get(name, default)) for cdm in cdm_sequence])

    miss = column("miss_distance", 0.0)
    risk = column("risk", -10.0)
    tca = column("time_to_tca", 0.0)

    count = len(cdm_sequence)
    multi = count > 1
    miss_mean = float(np.mean(miss))
    miss_std = float(np.std(miss)) if multi else 0.0
    # A linear fit vs. time_to_tca needs >1 point and non-constant x values.
    can_fit = multi and np.std(tca) > 0
    miss_trend = float(np.polyfit(tca, miss, 1)[0]) if can_fit else 0.0
    risk_trend = float(np.polyfit(tca, risk, 1)[0]) if can_fit else 0.0

    summary = np.array(
        [
            count,
            miss_mean,
            miss_std,
            miss_trend,
            risk_trend,
            float(miss[0] - miss[-1]) if multi else 0.0,
            float(latest.get("time_to_tca", 0.0)),
            float(latest.get("relative_speed", 0.0)),
        ],
        dtype=np.float32,
    )
    row = np.nan_to_num(np.concatenate([base, summary]), nan=0.0, posinf=0.0, neginf=0.0)
    row = row[None, :]

    # Match the width the trained scaler expects (it may have been fit on
    # augmented data with extra columns): zero-pad on the right or truncate.
    if "xgboost" in models:
        expected = models["xgboost"].scaler.n_features_in_
        width = row.shape[1]
        if width < expected:
            row = np.hstack([row, np.zeros((row.shape[0], expected - width), dtype=row.dtype)])
        elif width > expected:
            row = row[:, :expected]
    return row
def _safe_std(values) -> np.ndarray:
    """Return ``values`` as float32 with zero entries replaced by 1.0, so division is safe."""
    std = np.asarray(values, dtype=np.float32)
    return np.where(std == 0, np.float32(1.0), std)


def _run_pitft_inference(cdm_sequence: list[dict]) -> tuple[float, float, float]:
    """Run PI-TFT inference on a single, non-empty CDM sequence.

    Uses the model, checkpoint and device stored in ``models`` by the loader.
    Features follow the checkpoint's ``temporal_cols`` / ``static_cols``,
    falling back to TEMPORAL_FEATURES / STATIC_FEATURES. They are standardized
    with the checkpoint's stored statistics. The sequence is then truncated to
    its most recent MAX_SEQ_LEN steps, or left-padded up to that length.

    Returns:
        (risk_probability, miss_log, pc_log10):
        - the temperature-scaled sigmoid of the risk logit;
        - the model's raw miss-distance output;
        - the model's raw log10 collision-probability output.
    """
    checkpoint = models["pitft_checkpoint"]
    device = models["pitft_device"]
    model = models["pitft"]
    norm = checkpoint["normalization"]
    temperature = checkpoint.get("temperature", 1.0)
    temporal_cols = checkpoint.get("temporal_cols", TEMPORAL_FEATURES)
    static_cols = checkpoint.get("static_cols", STATIC_FEATURES)

    # Extract temporal features: (S, F_t)
    temporal = np.array([
        [float(cdm.get(col, 0.0)) for col in temporal_cols]
        for cdm in cdm_sequence
    ], dtype=np.float32)
    temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

    # First differences between consecutive CDMs; the first step gets a zero delta.
    if len(temporal) > 1:
        deltas = np.diff(temporal, axis=0)
        deltas = np.concatenate(
            [np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0
        )
    else:
        deltas = np.zeros_like(temporal)

    # Standardize with the training statistics stored in the checkpoint. A zero
    # std (constant feature) would otherwise create inf/NaN inputs, and NaN
    # outputs cannot be serialized into the JSON response.
    t_mean = np.array(norm["temporal_mean"], dtype=np.float32)
    d_mean = np.array(norm["delta_mean"], dtype=np.float32)
    s_mean = np.array(norm["static_mean"], dtype=np.float32)
    temporal = (temporal - t_mean) / _safe_std(norm["temporal_std"])
    deltas = (deltas - d_mean) / _safe_std(norm["delta_std"])
    temporal = np.concatenate([temporal, deltas], axis=1)  # (S, 2 * F_t)

    # Static features from last CDM
    last_cdm = cdm_sequence[-1]
    static = np.array(
        [float(last_cdm.get(col, 0.0)) for col in static_cols], dtype=np.float32
    )
    static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
    static = (static - s_mean) / _safe_std(norm["static_std"])

    # Time-to-TCA, one column per step; sanitized like the other features.
    tca = np.array(
        [float(cdm.get("time_to_tca", 0.0)) for cdm in cdm_sequence], dtype=np.float32
    ).reshape(-1, 1)
    tca = np.nan_to_num(tca, nan=0.0, posinf=0.0, neginf=0.0)
    tca = (tca - norm["tca_mean"]) / _safe_std(norm["tca_std"])

    # Keep the most recent MAX_SEQ_LEN steps; left-pad shorter sequences.
    seq_len = len(temporal)
    if seq_len > MAX_SEQ_LEN:
        temporal = temporal[-MAX_SEQ_LEN:]
        tca = tca[-MAX_SEQ_LEN:]
        seq_len = MAX_SEQ_LEN
    pad_len = MAX_SEQ_LEN - seq_len
    if pad_len > 0:
        temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
        tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)
    # True marks real (non-padded) steps.
    mask = np.zeros(MAX_SEQ_LEN, dtype=bool)
    mask[pad_len:] = True

    # Convert to tensors with a leading batch dimension of 1.
    temporal_t = torch.tensor(temporal, dtype=torch.float32).unsqueeze(0).to(device)
    static_t = torch.tensor(static, dtype=torch.float32).unsqueeze(0).to(device)
    tca_t = torch.tensor(tca, dtype=torch.float32).unsqueeze(0).to(device)
    mask_t = torch.tensor(mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model's fourth output is not used here.
        risk_logit, miss_log, pc_log10, _ = model(temporal_t, static_t, tca_t, mask_t)
    risk_prob = float(torch.sigmoid(risk_logit / temperature).cpu().item())
    miss_log_val = float(miss_log.cpu().item())
    pc_log10_val = float(pc_log10.cpu().item())
    return risk_prob, miss_log_val, pc_log10_val