# Generated by Claude Code -- 2026-02-13
"""TLE Staleness Sensitivity Experiment.

Evaluates how model performance degrades as CDM data becomes stale.
Simulates staleness by filtering CDM sequences to only include updates
received at least `cutoff_days` before TCA.

The Kelvins test set has time_to_tca in [2.0, 7.0] days, so meaningful
cutoffs are in that range. A cutoff of 2.0 keeps all data (baseline),
while a cutoff of 6.0 keeps only the earliest CDMs.

Ground-truth labels always come from the ORIGINAL (untruncated) test set --
we're measuring how well models predict with less-recent information.
"""
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from src.data.cdm_loader import build_events, events_to_flat_features, get_feature_columns | |
| from src.data.sequence_builder import CDMSequenceDataset | |
| from src.evaluation.metrics import evaluate_risk | |
| # Staleness cutoffs (days before TCA) | |
| # 2.0 = keep all data (baseline), 6.0 = only very early CDMs | |
| DEFAULT_CUTOFFS = [2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0] | |
| QUICK_CUTOFFS = [2.0, 4.0, 6.0] | |
def truncate_cdm_dataframe(df: pd.DataFrame, cutoff_days: float) -> pd.DataFrame:
    """Keep only CDM rows received at least `cutoff_days` before TCA.

    Simulates data staleness: with cutoff=4.0 the model only sees CDMs
    that arrived 4 or more days before closest approach.
    """
    stale_enough = df["time_to_tca"] >= cutoff_days
    return df.loc[stale_enough].copy()
def get_ground_truth_labels(df: pd.DataFrame) -> dict:
    """Extract per-event ground truth labels from the FULL (untruncated) dataset.

    Labels come from the final CDM per event (closest to TCA).

    Returns: {event_id: {"risk_label": int, "miss_log": float, "altitude_km": float}}
    """
    out: dict = {}
    for eid, rows in df.groupby("event_id"):
        # Final CDM = the row with the smallest time_to_tca (closest to TCA);
        # idxmin returns the first occurrence, matching a stable ascending sort.
        last = rows.loc[rows["time_to_tca"].idxmin()]
        out[int(eid)] = {
            "risk_label": 1 if last["risk"] > -5 else 0,
            "miss_log": float(np.log1p(max(last.get("miss_distance", 0.0), 0.0))),
            "altitude_km": float(last.get("t_h_apo", 0.0)),
        }
    return out
def evaluate_baseline_at_cutoff(baseline_model, ground_truth: dict, cutoff: float) -> dict:
    """Evaluate baseline model. Uses altitude only, unaffected by staleness."""
    alts = []
    labels = []
    for gt in ground_truth.values():
        alts.append(gt["altitude_km"])
        labels.append(gt["risk_label"])
    y_true = np.array(labels)
    risk_probs, _ = baseline_model.predict(np.array(alts))
    result = evaluate_risk(y_true, risk_probs)
    result["cutoff"] = cutoff
    result["n_events"] = len(y_true)
    return result
def evaluate_xgboost_at_cutoff(
    xgboost_model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    feature_cols: list[str],
    cutoff: float,
) -> dict:
    """Evaluate XGBoost on staleness-truncated CDM data.

    Args:
        xgboost_model: fitted model exposing `scaler.n_features_in_` and
            `predict_risk(X)`.
        truncated_df: CDM rows surviving the staleness cutoff.
        ground_truth: {event_id: label dict} built from the UNTRUNCATED test set.
        feature_cols: feature column names passed to `build_events`.
        cutoff: staleness cutoff in days (recorded in the returned metrics).

    Returns:
        Metrics dict from `evaluate_risk`, plus "cutoff" and "n_events".
        Returns zeroed metrics when no events or no positives survive.
    """
    events = build_events(truncated_df, feature_cols)
    if len(events) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
    X, _, _ = events_to_flat_features(events)
    # Align the feature width with what the model was trained on: pad with
    # zeros if the model expects more columns (e.g. trained on augmented
    # data), truncate if it expects fewer. (The original code did the
    # padding twice, in two places; consolidated into this single step.)
    expected = xgboost_model.scaler.n_features_in_
    if X.shape[1] < expected:
        X = np.pad(X, ((0, 0), (0, expected - X.shape[1])), constant_values=0)
    elif X.shape[1] > expected:
        X = X[:, :expected]
    # Keep only events that have a ground-truth label.
    event_ids = [e.event_id for e in events]
    valid_mask = np.array([eid in ground_truth for eid in event_ids])
    X = X[valid_mask]
    valid_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])
    # AUC-PR is undefined with zero positives; report zeros rather than crash.
    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}
    risk_probs = xgboost_model.predict_risk(X)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics
def evaluate_pitft_at_cutoff(
    model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    train_ds: CDMSequenceDataset,
    device: torch.device,
    temperature: float = 1.0,
    cutoff: float = 0.0,
    batch_size: int = 128,
) -> dict:
    """Evaluate PI-TFT on truncated CDM data with temperature scaling."""
    frame = truncated_df.copy()
    # The sequence dataset needs every training column; zero-fill any missing.
    for column in train_ds.temporal_cols + train_ds.static_cols:
        if column not in frame.columns:
            frame[column] = 0.0
    test_ds = CDMSequenceDataset(
        frame,
        temporal_cols=train_ds.temporal_cols,
        static_cols=train_ds.static_cols,
    )
    # Reuse training normalization stats so inputs match what the model saw.
    test_ds.set_normalization(train_ds)
    if len(test_ds) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
    event_ids = [e["event_id"] for e in test_ds.events]
    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    model.eval()
    prob_chunks = []
    with torch.no_grad():
        for batch in loader:
            logits, _, _, _ = model(
                batch["temporal"].to(device),
                batch["static"].to(device),
                batch["time_to_tca"].to(device),
                batch["mask"].to(device),
            )
            # Temperature-scaled sigmoid -> calibrated probabilities.
            scaled = torch.sigmoid(logits / temperature)
            prob_chunks.append(scaled.cpu().numpy().flatten())
    risk_probs = np.concatenate(prob_chunks)
    # Align predictions with events that have ground-truth labels.
    keep = np.array([eid in ground_truth for eid in event_ids])
    risk_probs = risk_probs[keep]
    labeled_ids = [eid for eid in event_ids if eid in ground_truth]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in labeled_ids])
    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = int(len(y_true))
    return metrics
def run_staleness_experiment(
    baseline_model,
    xgboost_model,
    pitft_model,
    pitft_checkpoint: dict,
    test_df: pd.DataFrame,
    train_ds: CDMSequenceDataset,
    feature_cols: list[str],
    device: torch.device,
    cutoffs: list[float] | None = None,  # annotation fixed: default is None
    quick: bool = False,
) -> dict:
    """Run the full staleness experiment across all cutoffs and models.

    Args:
        baseline_model: OrbitalShellBaseline instance
        xgboost_model: XGBoostConjunctionModel instance
        pitft_model: PhysicsInformedTFT (eval mode), or None to skip
        pitft_checkpoint: checkpoint dict with temperature
        test_df: ORIGINAL (untruncated) test DataFrame
        train_ds: CDMSequenceDataset from training data (for normalization)
        feature_cols: list of feature column names for XGBoost
        device: torch device
        cutoffs: list of staleness cutoffs (days before TCA); defaults to
            QUICK_CUTOFFS when `quick` else DEFAULT_CUTOFFS
        quick: if True, use fewer cutoffs

    Returns:
        Dict with "cutoffs", "n_test_events", "n_positive", and per-model
        lists of metric dicts ("baseline", "xgboost", "pitft"), one entry
        per cutoff.
    """
    if cutoffs is None:
        cutoffs = QUICK_CUTOFFS if quick else DEFAULT_CUTOFFS
    # Labels always come from the untruncated data (see module docstring).
    ground_truth = get_ground_truth_labels(test_df)
    n_pos = sum(1 for gt in ground_truth.values() if gt["risk_label"] == 1)
    print(f"\nGround truth: {len(ground_truth)} events, {n_pos} positive")
    temperature = 1.0
    if pitft_checkpoint:
        temperature = pitft_checkpoint.get("temperature", 1.0)
    results = {
        "cutoffs": cutoffs,
        "n_test_events": len(ground_truth),
        "n_positive": n_pos,
        "baseline": [],
        "xgboost": [],
        "pitft": [],
    }
    for cutoff in cutoffs:
        print(f"\n{'='*50}")
        print(f"Staleness cutoff: {cutoff:.1f} days")
        print(f"{'='*50}")
        truncated = truncate_cdm_dataframe(test_df, cutoff)
        n_events = truncated["event_id"].nunique()
        n_rows = len(truncated)
        print(f"  Surviving: {n_events} events, {n_rows} CDMs")
        # Baseline (uses altitude only -- constant across cutoffs)
        bl = evaluate_baseline_at_cutoff(baseline_model, ground_truth, cutoff)
        results["baseline"].append(bl)
        print(f"  Baseline AUC-PR={bl.get('auc_pr', 0):.4f}, F1={bl.get('f1', 0):.4f}")
        # XGBoost -- skipped (zeroed metrics) when no events survive
        if n_events > 0:
            xgb = evaluate_xgboost_at_cutoff(
                xgboost_model, truncated, ground_truth, feature_cols, cutoff
            )
        else:
            xgb = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["xgboost"].append(xgb)
        print(f"  XGBoost AUC-PR={xgb.get('auc_pr', 0):.4f}, "
              f"F1={xgb.get('f1', 0):.4f} ({xgb.get('n_events', 0)} events)")
        # PI-TFT -- skipped when no events survive or the model is absent
        if n_events > 0 and pitft_model is not None:
            tft = evaluate_pitft_at_cutoff(
                pitft_model, truncated, ground_truth, train_ds,
                device, temperature=temperature, cutoff=cutoff,
            )
        else:
            tft = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["pitft"].append(tft)
        print(f"  PI-TFT AUC-PR={tft.get('auc_pr', 0):.4f}, "
              f"F1={tft.get('f1', 0):.4f}")
    return results