# panacea-api / src/evaluation/staleness.py
# Uploaded by DTanzillo via huggingface_hub (commit a4b5ecb).
# Generated by Claude Code -- 2026-02-13
"""TLE Staleness Sensitivity Experiment.
Evaluates how model performance degrades as CDM data becomes stale.
Simulates staleness by filtering CDM sequences to only include updates
received at least `cutoff_days` before TCA.
The Kelvins test set has time_to_tca in [2.0, 7.0] days, so meaningful
cutoffs are in that range. A cutoff of 2.0 keeps all data (baseline),
while a cutoff of 6.0 keeps only the earliest CDMs.
Ground-truth labels always come from the ORIGINAL (untruncated) test set —
we're measuring how well models predict with less-recent information.
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from src.data.cdm_loader import build_events, events_to_flat_features, get_feature_columns
from src.data.sequence_builder import CDMSequenceDataset
from src.evaluation.metrics import evaluate_risk
# Staleness cutoffs, in days before TCA. The smallest value (2.0) is the lower
# edge of the test-set time_to_tca range described in the module docstring, so
# it keeps every CDM (the baseline); larger values keep only earlier CDMs,
# with 6.0 leaving just the very earliest updates.
DEFAULT_CUTOFFS = [2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0]
# Reduced sweep for quick runs: baseline, mid-range, and most-stale cutoffs.
QUICK_CUTOFFS = [DEFAULT_CUTOFFS[0], DEFAULT_CUTOFFS[4], DEFAULT_CUTOFFS[-1]]
def truncate_cdm_dataframe(df: pd.DataFrame, cutoff_days: float) -> pd.DataFrame:
    """Keep only the CDM rows received at least ``cutoff_days`` before TCA.

    This simulates stale data: with ``cutoff_days=4.0`` only CDMs whose
    ``time_to_tca`` is 4.0 or more survive. Rows with a missing (NaN)
    ``time_to_tca`` fail the comparison and are dropped.

    Returns:
        An independent copy of the surviving rows (original index preserved),
        so callers may modify it without touching ``df``.
    """
    survives = df["time_to_tca"].ge(cutoff_days)
    return df.loc[survives].copy()
def get_ground_truth_labels(df: pd.DataFrame) -> dict:
    """Build per-event ground-truth labels from the FULL (untruncated) CDM data.

    For every ``event_id`` the CDM with the smallest ``time_to_tca`` (the one
    closest to TCA) is treated as the final word on that event.

    Label definitions:
        risk_label:  1 when the final ``risk`` value exceeds -5, else 0
                     (``risk`` is presumably a log10 collision probability —
                     confirm against the data source).
        miss_log:    ``log1p`` of the final ``miss_distance`` clamped at 0;
                     0.0 when the column is absent.
        altitude_km: the final ``t_h_apo`` value as a float; 0.0 when absent.

    Returns:
        ``{event_id: {"risk_label": int, "miss_log": float, "altitude_km": float}}``
        with integer event IDs, in ``groupby`` (sorted) order.
    """

    def _summarize_final_cdm(row: pd.Series) -> dict:
        # Argument order to max() matters for NaN inputs; keep value first.
        clamped_miss = max(row.get("miss_distance", 0.0), 0.0)
        return {
            "risk_label": int(row["risk"] > -5),
            "miss_log": float(np.log1p(clamped_miss)),
            "altitude_km": float(row.get("t_h_apo", 0.0)),
        }

    return {
        int(eid): _summarize_final_cdm(
            cdms.sort_values("time_to_tca", ascending=True).iloc[0]
        )
        for eid, cdms in df.groupby("event_id")
    }
def evaluate_baseline_at_cutoff(baseline_model, ground_truth: dict, cutoff: float) -> dict:
    """Score the altitude-only baseline against the ground-truth risk labels.

    The baseline receives only each event's ``altitude_km`` (taken from the
    untruncated data), so ``cutoff`` does not affect its predictions; it is
    recorded in the result so rows line up with the other models' results.

    Args:
        baseline_model: object whose ``predict(altitudes)`` returns a pair
            ``(risk_probs, other)``; only ``risk_probs`` is used.
        ground_truth: mapping produced by ``get_ground_truth_labels``.
        cutoff: staleness cutoff (days before TCA), echoed into the result.

    Returns:
        The ``evaluate_risk`` metrics dict with ``cutoff`` and ``n_events`` added.
    """
    entries = list(ground_truth.values())
    alts = np.array([entry["altitude_km"] for entry in entries])
    labels = np.array([entry["risk_label"] for entry in entries])

    probs, _unused = baseline_model.predict(alts)

    result = evaluate_risk(labels, probs)
    result.update(cutoff=cutoff, n_events=len(labels))
    return result
def evaluate_xgboost_at_cutoff(
    xgboost_model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    feature_cols: list[str],
    cutoff: float,
) -> dict:
    """Evaluate the XGBoost model on CDM data truncated at ``cutoff``.

    Events are rebuilt from the surviving CDM rows and flattened into feature
    vectors. They are restricted to events that have a ground-truth label and
    then scored against those labels.

    Args:
        xgboost_model: fitted model exposing ``scaler.n_features_in_`` and
            ``predict_risk(X)``.
        truncated_df: CDM rows that survived the staleness cutoff.
        ground_truth: ``{event_id: {"risk_label": int, ...}}`` built from the
            untruncated test set.
        feature_cols: feature column names passed to ``build_events``.
        cutoff: staleness cutoff (days before TCA), echoed into the result.

    Returns:
        The ``evaluate_risk`` metrics dict plus ``cutoff`` and ``n_events``.
        When no events survive, or the labelled survivors contain no positive
        events, a placeholder with ``auc_pr`` and ``f1`` of 0.0 is returned.
    """
    events = build_events(truncated_df, feature_cols)
    if len(events) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    X, _, _ = events_to_flat_features(events)

    # Keep only events that have a ground-truth label; the IDs and the rows of
    # X are filtered with the same mask so they stay aligned.
    event_ids = [e.event_id for e in events]
    valid_mask = np.array([eid in ground_truth for eid in event_ids], dtype=bool)
    X = X[valid_mask]
    valid_ids = [eid for eid, keep in zip(event_ids, valid_mask) if keep]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in valid_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    # Align the feature width with what the model's scaler was fitted on.
    # Zero-pad when the model expects more columns (e.g. it was trained on
    # augmented data). Drop trailing columns when X has more than expected.
    expected = xgboost_model.scaler.n_features_in_
    n_cols = X.shape[1]
    if n_cols < expected:
        X = np.pad(X, ((0, 0), (0, expected - n_cols)), constant_values=0)
    elif n_cols > expected:
        X = X[:, :expected]

    risk_probs = xgboost_model.predict_risk(X)
    metrics = evaluate_risk(y_true, risk_probs)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = len(y_true)
    return metrics
def evaluate_pitft_at_cutoff(
    model,
    truncated_df: pd.DataFrame,
    ground_truth: dict,
    train_ds: CDMSequenceDataset,
    device: torch.device,
    temperature: float = 1.0,
    cutoff: float = 0.0,
    batch_size: int = 128,
) -> dict:
    """Score the PI-TFT model on staleness-truncated CDM sequences.

    The truncated rows are wrapped in a ``CDMSequenceDataset`` that reuses the
    training dataset's column lists and normalization. Risk logits are divided
    by ``temperature`` before the sigmoid (temperature scaling). The resulting
    probabilities are compared with the ground-truth labels of the surviving
    events.

    Args:
        model: PI-TFT module; called as ``model(temporal, static, tca, mask)``
            and expected to return four outputs, the first being risk logits.
            It is switched to eval mode here.
        truncated_df: CDM rows that survived the staleness cutoff.
        ground_truth: mapping produced by ``get_ground_truth_labels``.
        train_ds: training dataset supplying column lists and normalization.
        device: device the batch tensors are moved to.
        temperature: logit divisor applied before the sigmoid.
        cutoff: staleness cutoff (days before TCA), echoed into the result.
        batch_size: inference batch size.

    Returns:
        The ``evaluate_risk`` metrics dict plus ``cutoff`` and ``n_events``, or
        a 0.0 placeholder when nothing survives or no labelled positives remain.
    """
    frame = truncated_df.copy()
    # Zero-fill every training column missing from the truncated frame so the
    # dataset is built with the same schema the model was trained on.
    for column in train_ds.temporal_cols + train_ds.static_cols:
        if column not in frame.columns:
            frame[column] = 0.0

    test_ds = CDMSequenceDataset(
        frame,
        temporal_cols=train_ds.temporal_cols,
        static_cols=train_ds.static_cols,
    )
    test_ds.set_normalization(train_ds)
    if len(test_ds) == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}

    # Predictions are matched to IDs positionally. This assumes test_ds[i]
    # corresponds to test_ds.events[i] (TODO confirm). shuffle=False keeps
    # the batches in dataset order.
    event_ids = [event["event_id"] for event in test_ds.events]
    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model.eval()
    prob_chunks = []
    with torch.no_grad():
        for batch in loader:
            inputs = [
                batch[key].to(device)
                for key in ("temporal", "static", "time_to_tca", "mask")
            ]
            risk_logit, _aux1, _aux2, _aux3 = model(*inputs)
            scaled_logit = risk_logit / temperature
            prob_chunks.append(torch.sigmoid(scaled_logit).cpu().numpy().flatten())
    predicted = np.concatenate(prob_chunks)

    # Drop predictions for events without a ground-truth label.
    has_label = np.array([eid in ground_truth for eid in event_ids])
    predicted = predicted[has_label]
    labelled_ids = [eid for eid, ok in zip(event_ids, has_label) if ok]
    y_true = np.array([ground_truth[eid]["risk_label"] for eid in labelled_ids])

    if len(y_true) == 0 or y_true.sum() == 0:
        return {"auc_pr": 0.0, "f1": 0.0, "n_events": len(y_true), "cutoff": cutoff}

    metrics = evaluate_risk(y_true, predicted)
    metrics["cutoff"] = cutoff
    metrics["n_events"] = int(len(y_true))
    return metrics
def run_staleness_experiment(
    baseline_model,
    xgboost_model,
    pitft_model,
    pitft_checkpoint: dict,
    test_df: pd.DataFrame,
    train_ds: CDMSequenceDataset,
    feature_cols: list[str],
    device: torch.device,
    cutoffs: list[float] = None,
    quick: bool = False,
) -> dict:
    """Run the full staleness experiment across all cutoffs and models.

    Ground-truth labels are computed once from the untruncated ``test_df``.
    For each cutoff, the CDMs are truncated and every model is evaluated.
    The baseline is always scored on all ground-truth events. XGBoost and
    PI-TFT are scored only on events that still have CDMs after truncation,
    so their ``n_events`` may shrink as the cutoff grows.

    Args:
        baseline_model: OrbitalShellBaseline instance
        xgboost_model: XGBoostConjunctionModel instance
        pitft_model: PhysicsInformedTFT (eval mode), or None to skip
        pitft_checkpoint: checkpoint dict with temperature (falsy -> 1.0)
        test_df: ORIGINAL (untruncated) test DataFrame
        train_ds: CDMSequenceDataset from training data (for normalization)
        feature_cols: list of feature column names for XGBoost
        device: torch device
        cutoffs: iterable of staleness cutoffs (days before TCA); defaults to
            QUICK_CUTOFFS when ``quick`` else DEFAULT_CUTOFFS
        quick: if True and ``cutoffs`` is None, use fewer cutoffs

    Returns:
        Dict with keys ``cutoffs`` (a fresh list), ``n_test_events``,
        ``n_positive``, and per-model lists ``baseline``, ``xgboost`` and
        ``pitft``. Each list holds one metrics dict per cutoff, in order.
        When a model is skipped (no surviving events, or no PI-TFT model),
        its entry is a placeholder with ``auc_pr``/``f1`` of 0.0.
    """
    if cutoffs is None:
        cutoffs = QUICK_CUTOFFS if quick else DEFAULT_CUTOFFS
    # Copy so the returned results never alias the module-level default lists
    # (or the caller's list); mutating results["cutoffs"] later must not change
    # the defaults used by subsequent runs.
    cutoffs = list(cutoffs)

    ground_truth = get_ground_truth_labels(test_df)
    n_pos = sum(1 for gt in ground_truth.values() if gt["risk_label"] == 1)
    print(f"\nGround truth: {len(ground_truth)} events, {n_pos} positive")

    temperature = pitft_checkpoint.get("temperature", 1.0) if pitft_checkpoint else 1.0

    results = {
        "cutoffs": cutoffs,
        "n_test_events": len(ground_truth),
        "n_positive": n_pos,
        "baseline": [],
        "xgboost": [],
        "pitft": [],
    }

    for cutoff in cutoffs:
        print(f"\n{'='*50}")
        print(f"Staleness cutoff: {cutoff:.1f} days")
        print(f"{'='*50}")

        truncated = truncate_cdm_dataframe(test_df, cutoff)
        n_events = truncated["event_id"].nunique()
        n_rows = len(truncated)
        print(f"  Surviving: {n_events} events, {n_rows} CDMs")

        # Baseline (uses altitude only — constant across cutoffs)
        bl = evaluate_baseline_at_cutoff(baseline_model, ground_truth, cutoff)
        results["baseline"].append(bl)
        print(f"  Baseline  AUC-PR={bl.get('auc_pr', 0):.4f}, F1={bl.get('f1', 0):.4f}")

        # XGBoost
        if n_events > 0:
            xgb = evaluate_xgboost_at_cutoff(
                xgboost_model, truncated, ground_truth, feature_cols, cutoff
            )
        else:
            xgb = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["xgboost"].append(xgb)
        print(f"  XGBoost   AUC-PR={xgb.get('auc_pr', 0):.4f}, "
              f"F1={xgb.get('f1', 0):.4f}  ({xgb.get('n_events', 0)} events)")

        # PI-TFT (skipped when no model was supplied)
        if n_events > 0 and pitft_model is not None:
            tft = evaluate_pitft_at_cutoff(
                pitft_model, truncated, ground_truth, train_ds,
                device, temperature=temperature, cutoff=cutoff,
            )
        else:
            tft = {"auc_pr": 0.0, "f1": 0.0, "n_events": 0, "cutoff": cutoff}
        results["pitft"].append(tft)
        print(f"  PI-TFT    AUC-PR={tft.get('auc_pr', 0):.4f}, "
              f"F1={tft.get('f1', 0):.4f}  ({tft.get('n_events', 0)} events)")

    return results