# Generated by Claude Code -- 2026-02-13
"""TLE Staleness Sensitivity Experiment.

Evaluates how model performance degrades as CDM data becomes stale.
Simulates staleness by filtering CDM sequences to only include updates
received at least `cutoff_days` before TCA.

The Kelvins test set has time_to_tca in [2.0, 7.0] days, so meaningful
cutoffs are in that range. A cutoff of 2.0 keeps all data (baseline),
while a cutoff of 6.0 keeps only the earliest CDMs.

Ground-truth labels always come from the ORIGINAL (untruncated) test set;
we are measuring how well models predict with less-recent information.
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from src.data.cdm_loader import build_events, events_to_flat_features, get_feature_columns
from src.data.sequence_builder import CDMSequenceDataset
from src.evaluation.metrics import evaluate_risk

# Staleness cutoffs (days before TCA)
# 2.0 = keep all data (baseline), 6.0 = only very early CDMs
DEFAULT_CUTOFFS = [2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0]
QUICK_CUTOFFS = [2.0, 4.0, 6.0]


def truncate_cdm_dataframe(df: pd.DataFrame, cutoff_days: float) -> pd.DataFrame:
    """Filter CDM rows to only those with time_to_tca >= cutoff_days.

    Simulates data staleness: if cutoff=4.0, the model only sees CDMs
    that arrived 4+ days before closest approach.
    """
    return df[df["time_to_tca"] >= cutoff_days].copy()


def get_ground_truth_labels(df: pd.DataFrame) -> dict:
    """Extract per-event ground truth labels from the FULL (untruncated) dataset.

    Labels come from the final CDM per event (closest to TCA).

    Returns:
        {event_id: {"risk_label": int, "miss_log": float, "altitude_km": float}}
    """
    labels = {}
    for event_id, group in df.groupby("event_id"):
        # Sort ascending so iloc[0] is the CDM closest to TCA.
        group = group.sort_values("time_to_tca", ascending=True)
        final = group.iloc[0]
        risk_label = 1 if final["risk"] > -5 else 0
        miss_log = float(np.log1p(max(final.get("miss_distance", 0.0), 0.0)))
        alt = float(final.get("t_h_apo", 0.0))
        labels[int(event_id)] = {
            "risk_label": risk_label,
            "miss_log": miss_log,
            "altitude_km": alt,
        }
    return labels

def evaluate_baseline_at_cutoff(baseline_model, ground_truth: dict, cutoff: float) -> dict:
    """Evaluate the baseline model.

    Uses altitude only, so results are unaffected by staleness.
    """
    altitudes = np.array([gt["altitude_km"] for gt in ground_truth.values()])
    y_true = np.array([gt["risk_label"] for gt in ground_truth.values()])
    risk_probs, _ = baseline_model.predict(altitudes)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics


def evaluate_xgboost_at_cutoff(
    xgboost_model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    feature_cols: list[str],
    cutoff: float,
) -> dict:
    """Evaluate XGBoost on truncated CDM data."""
    events = build_events(truncated_df, feature_cols)
    if len(events) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    X, _, _ = events_to_flat_features(events)

    # Keep only events that have a ground-truth label.
    event_ids = [e.event_id for e in events]
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    X = X[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    # Align the feature count with what the fitted scaler expects
    # (e.g., the model may have been trained on augmented data with extra columns).
    expected = xgboost_model.scaler.n_features_in_
    if X.shape[1] < expected:
        X = np.pad(X, ((0, 0), (0, expected - X.shape[1])), constant_values=0)
    elif X.shape[1] > expected:
        X = X[:, :expected]

    risk_probs = xgboost_model.predict_risk(X)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics


def evaluate_pitft_at_cutoff(
    model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    train_ds: CDMSequenceDataset,
    device: torch.device,
    temperature: float = 1.0,
    cutoff: float = 0.0,
    batch_size: int = 128,
) -> dict:
    """Evaluate PI-TFT on truncated CDM data with temperature scaling."""
    # Ensure all required columns exist (pad missing with 0)
    df = truncated_df.copy()
    for col in train_ds.temporal_cols + train_ds.static_cols:
        if col not in df.columns:
            df[col] = 0.0

    test_ds = CDMSequenceDataset(
        df,
        temporal_cols=train_ds.temporal_cols,
        static_cols=train_ds.static_cols,
    )
    test_ds.set_normalization(train_ds)

    if len(test_ds) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    # Get event IDs from the dataset
    event_ids = [e["event_id"] for e in test_ds.events]

    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    model.eval()
    all_probs = []
    with torch.no_grad():
        for batch in loader:
            temporal = batch["temporal"].to(device)
            static = batch["static"].to(device)
            tca = batch["time_to_tca"].to(device)
            mask = batch["mask"].to(device)
            risk_logit, _, _, _ = model(temporal, static, tca, mask)
            probs = torch.sigmoid(risk_logit / temperature).cpu().numpy().flatten()
            all_probs.append(probs)

    risk_probs = np.concatenate(all_probs)

    # Match predictions to ground truth
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    risk_probs = risk_probs[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = int(len(y_true))
    return metrics

def run_staleness_experiment(
    baseline_model,
    xgboost_model,
    pitft_model,
    pitft_checkpoint: dict,
    test_df: pd.DataFrame,
    train_ds: CDMSequenceDataset,
    feature_cols: list[str],
    device: torch.device,
    cutoffs: list[float] | None = None,
    quick: bool = False,
) -> dict:
    """Run the full staleness experiment across all cutoffs and models.

    Args:
        baseline_model: OrbitalShellBaseline instance
        xgboost_model: XGBoostConjunctionModel instance
        pitft_model: PhysicsInformedTFT (eval mode), or None to skip
        pitft_checkpoint: checkpoint dict with temperature
        test_df: ORIGINAL (untruncated) test DataFrame
        train_ds: CDMSequenceDataset from training data (for normalization)
        feature_cols: list of feature column names for XGBoost
        device: torch device
        cutoffs: list of staleness cutoffs (days before TCA)
        quick: if True, use fewer cutoffs
    """
    if cutoffs is None:
        cutoffs = QUICK_CUTOFFS if quick else DEFAULT_CUTOFFS

    ground_truth = get_ground_truth_labels(test_df)
    n_pos = sum(1 for gt in ground_truth.values() if gt["risk_label"] == 1)
    print(f"\nGround truth: {len(ground_truth)} events, {n_pos} positive")

    temperature = 1.0
    if pitft_checkpoint:
        temperature = pitft_checkpoint.get("temperature", 1.0)

    results = {
        "cutoffs": cutoffs,
        "n_test_events": len(ground_truth),
        "n_positive": n_pos,
        "baseline": [],
        "xgboost": [],
        "pitft": [],
    }

    for cutoff in cutoffs:
        print(f"\n{'=' * 50}")
        print(f"Staleness cutoff: {cutoff:.1f} days")
        print(f"{'=' * 50}")

        truncated = truncate_cdm_dataframe(test_df, cutoff)
        n_events = truncated["event_id"].nunique()
        n_rows = len(truncated)
        print(f"  Surviving: {n_events} events, {n_rows} CDMs")

        # Baseline (altitude only, so constant across cutoffs)
        bl = evaluate_baseline_at_cutoff(baseline_model, ground_truth, cutoff)
        results["baseline"].append(bl)
        print(f"  Baseline AUC-PR={bl.get('auc_pr', 0):.4f}, F1={bl.get('f1', 0):.4f}")

        # XGBoost
        if n_events > 0:
            xgb = evaluate_xgboost_at_cutoff(
                xgboost_model, truncated, ground_truth, feature_cols, cutoff
            )
        else:
            xgb = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["xgboost"].append(xgb)
        print(f"  XGBoost AUC-PR={xgb.get('auc_pr', 0):.4f}, "
              f"F1={xgb.get('f1', 0):.4f} ({xgb.get('n_events', 0)} events)")

        # PI-TFT
        if n_events > 0 and pitft_model is not None:
            tft = evaluate_pitft_at_cutoff(
                pitft_model, truncated, ground_truth, train_ds, device,
                temperature=temperature, cutoff=cutoff,
            )
        else:
            tft = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["pitft"].append(tft)
        print(f"  PI-TFT AUC-PR={tft.get('auc_pr', 0):.4f}, "
              f"F1={tft.get('f1', 0):.4f}")

    return results
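
# ---------------------------------------------------------------------------
# Illustrative smoke test (added sketch, not part of the original experiment
# pipeline). It uses a tiny synthetic CDM table with the column names already
# assumed above (event_id, time_to_tca, risk, miss_distance, t_h_apo); all
# values are made up for demonstration. It shows the staleness mechanism:
# truncate_cdm_dataframe() thins each event's history as the cutoff grows,
# while get_ground_truth_labels() always reads labels from the full table.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_df = pd.DataFrame(
        {
            "event_id": [1, 1, 1, 2, 2, 2],
            "time_to_tca": [6.5, 4.2, 2.1, 6.8, 5.0, 2.3],    # days before TCA
            "risk": [-8.0, -6.0, -4.0, -18.0, -19.0, -20.0],   # log10 collision probability
            "miss_distance": [900.0, 450.0, 120.0, 5000.0, 5200.0, 5100.0],
            "t_h_apo": [780.0, 780.0, 780.0, 550.0, 550.0, 550.0],  # target apogee (km)
        }
    )

    # Labels come from the untruncated table (final CDM per event):
    # event 1 ends above the -5 risk threshold (positive), event 2 does not.
    labels = get_ground_truth_labels(demo_df)
    print("Ground truth:", labels)

    # Each cutoff keeps only CDMs received at least `cutoff` days before TCA,
    # so larger cutoffs leave fewer (and older) rows per event.
    for cutoff in QUICK_CUTOFFS:
        kept = truncate_cdm_dataframe(demo_df, cutoff)
        print(
            f"cutoff={cutoff:.1f} d -> {kept['event_id'].nunique()} events, "
            f"{len(kept)} CDMs remain"
        )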