# Generated by Claude Code -- 2026-02-08
"""Build padded CDM sequences for the Temporal Fusion Transformer.

Each conjunction event is a variable-length time series of CDM snapshots.
This module handles:
- Selecting temporal vs static features
- Padding/truncating to fixed length
- Creating attention masks for padded positions
- Train/val/test splitting with stratification
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from pathlib import Path

# Maximum CDM sequence length (95th percentile of real data is ~25)
MAX_SEQ_LEN = 30

# Features that change with each CDM update (time-varying)
TEMPORAL_FEATURES = [
    "miss_distance",
    "relative_speed",
    "relative_position_r",
    "relative_position_t",
    "relative_position_n",
    "relative_velocity_r",
    "relative_velocity_t",
    "relative_velocity_n",
    "max_risk_estimate",
    "max_risk_scaling",
    # Target object covariance
    "t_sigma_r",
    "t_sigma_t",
    "t_sigma_n",
    "t_sigma_rdot",
    "t_sigma_tdot",
    "t_sigma_ndot",
    # Chaser object covariance
    "c_sigma_r",
    "c_sigma_t",
    "c_sigma_n",
    "c_sigma_rdot",
    "c_sigma_tdot",
    "c_sigma_ndot",
]

# Features that are constant per event (object properties)
STATIC_FEATURES = [
    "t_h_apo",
    "t_h_per",
    "t_j2k_sma",
    "t_j2k_inc",
    "t_ecc",
    "c_h_apo",
    "c_h_per",
    "c_j2k_sma",
    "c_j2k_inc",
    "c_ecc",
    "t_span",
    "c_span",
]

# Orbital density features from CRASH Clock analysis (added by OrbitalDensityComputer)
DENSITY_FEATURES = [
    "shell_density",
    "shell_collision_rate",
    "local_crash_clock_log",
    "altitude_percentile",
    "n_events_in_shell",
    "shell_risk_rate",
]


def find_available_features(df: pd.DataFrame, candidates: list[str]) -> list[str]:
    """Filter feature list to only columns that exist in the DataFrame.

    Prints a one-line note when at least one candidate is absent.

    Args:
        df: DataFrame whose columns are checked.
        candidates: Feature names to look for, in the desired order.

    Returns:
        The candidates present in ``df.columns``, in their original order.
    """
    columns = df.columns
    available = [name for name in candidates if name in columns]
    n_missing = len(candidates) - len(available)
    if n_missing:
        print(f" Note: {n_missing} features not in dataset, using {len(available)}")
    return available


def _safe_mean_std(frame: pd.DataFrame, cols: list[str]) -> tuple:
    """Return column-wise (mean, std) as float32 arrays that are safe divisors.

    A non-finite mean becomes 0 and a non-finite or near-zero std becomes 1.
    pandas yields NaN std for a single-row frame and NaN mean/std for an
    all-NaN column; a plain ``std < 1e-8`` check does not catch NaN, which
    would turn every normalized value into NaN.
    """
    mean = frame[cols].mean().values.astype(np.float32)
    std = frame[cols].std().values.astype(np.float32)
    mean = np.nan_to_num(mean, nan=0.0, posinf=0.0, neginf=0.0)
    usable = np.isfinite(std) & (std >= 1e-8)
    std = np.where(usable, std, 1.0).astype(np.float32)
    return mean, std


class _CDMSequenceBase(Dataset):
    """Shared machinery for the CDM sequence datasets in this module.

    Groups the frame by ``event_id``, sorts each event by ``time_to_tca``
    descending, computes normalization statistics, and turns one event into
    left-padded, fixed-length model inputs. Subclasses decide what is stored
    per event and what ``__getitem__`` returns.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] = None,
        static_cols: list[str] = None,
    ):
        self.max_seq_len = max_seq_len

        # Find available features (an empty/None argument triggers auto-detection)
        self.temporal_cols = temporal_cols or find_available_features(df, TEMPORAL_FEATURES)
        self.static_cols = static_cols or find_available_features(df, STATIC_FEATURES)
        self._report_features()

        # Group by event_id; sort by time_to_tca descending
        # (first CDM = furthest from TCA)
        self.events = [
            self._make_event(event_id, group.sort_values("time_to_tca", ascending=False))
            for event_id, group in df.groupby("event_id")
        ]

        # Global normalization stats computed from the frame passed in
        self.temporal_mean, self.temporal_std = _safe_mean_std(df, self.temporal_cols)
        self.static_mean, self.static_std = _safe_mean_std(df, self.static_cols)

        # Normalize time_to_tca
        tca_mean = float(df["time_to_tca"].mean())
        tca_std = float(df["time_to_tca"].std())
        self.tca_mean = tca_mean if np.isfinite(tca_mean) else 0.0
        self.tca_std = tca_std if np.isfinite(tca_std) and tca_std >= 1e-8 else 1.0

        # Deltas have different magnitude than raw features, need separate stats
        self._compute_delta_stats(df)

    def _report_features(self):
        """Print the number of selected features (overridden per subclass)."""
        raise NotImplementedError

    def _make_event(self, event_id, group: pd.DataFrame) -> dict:
        """Build the per-event record stored in ``self.events``."""
        return {"event_id": event_id, "group": group}

    def _compute_delta_stats(self, df: pd.DataFrame):
        """Estimate normalization stats for temporal first-order differences.

        Each event is sorted by ``time_to_tca`` descending before differencing,
        so the statistics describe the same deltas that ``__getitem__``
        produces. Only the first 2000 events with at least two CDMs are used.
        """
        delta_samples = []
        for _, group in df.groupby("event_id"):
            if len(group) < 2:
                continue
            group = group.sort_values("time_to_tca", ascending=False)
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            delta_samples.append(np.diff(vals, axis=0))
            if len(delta_samples) >= 2000:  # cap for speed
                break

        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[self.delta_std < 1e-8] = 1.0
        else:
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other: "_CDMSequenceBase"):
        """Copy normalization stats from another dataset (e.g., training set)."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self):
        return len(self.events)

    def _build_inputs(self, group: pd.DataFrame) -> tuple:
        """Convert one sorted event into normalized, fixed-length arrays.

        Returns:
            ``(temporal, static, tca, mask)`` where ``temporal`` is
            ``(max_seq_len, 2 * n_temporal)`` (raw features followed by their
            deltas), ``static`` is ``(n_static,)``, ``tca`` is
            ``(max_seq_len, 1)`` and ``mask`` is a ``(max_seq_len,)`` bool
            array that is True on real positions and False on padding.
        """
        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)

        # First-order differences capture trends: is miss_distance shrinking?
        # Is covariance tightening?
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)  # (seq_len-1, n_temporal)
            # Prepend zeros for the first timestep (no prior to diff against)
            first_row = np.zeros((1, deltas.shape[1]), dtype=np.float32)
            deltas = np.concatenate([first_row, deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)

        # Normalize raw features and deltas separately, then concatenate:
        # (seq_len, n_temporal * 2)
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std
        temporal = np.concatenate([temporal, deltas], axis=1)

        # Extract static features from last row (they're constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
        static = (static - self.static_mean) / self.static_std

        # Time-to-TCA values: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)
        tca = (tca - self.tca_mean) / self.tca_std

        # Truncate: keep the most recent CDMs (closest to TCA = most informative)
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len

        # Pad (left-pad so the most recent CDM is always at position -1)
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)

        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True

        return temporal, static, tca, mask


class CDMSequenceDataset(_CDMSequenceBase):
    """
    PyTorch Dataset that serves padded CDM sequences for the Transformer.

    Each item is a dict containing:
    - temporal: (S, 2 * F_t) tensor of time-varying CDM features and their deltas
    - static: (F_s,) tensor of object properties
    - time_to_tca: (S, 1) tensor of time-to-closest-approach values
    - mask: (S,) boolean mask (True = real data, False = padding)
    - risk_label: scalar binary target (1.0 when the final ``risk`` > -5)
    - miss_log: scalar log1p(final miss_distance) target
    - pc_log10: scalar final ``risk`` clamped to [-20, 0]
    - domain_weight: scalar sample weight (1.0 for "kelvins", else 0.3)

    The input frame must contain ``event_id``, ``time_to_tca`` and ``risk``.
    """

    def _report_features(self):
        print(f" Temporal features: {len(self.temporal_cols)}")
        print(f" Static features: {len(self.static_cols)}")

    def _make_event(self, event_id, group: pd.DataFrame) -> dict:
        # Track data source for domain weighting
        source = "kelvins"
        if "source" in group.columns:
            source = group["source"].iloc[0]
        return {"event_id": event_id, "group": group, "source": source}

    def __getitem__(self, idx):
        event = self.events[idx]
        group = event["group"]
        temporal, static, tca, mask = self._build_inputs(group)

        # Target: risk label from final CDM's risk column
        # risk > -5 means collision probability > 1e-5 (high risk)
        final_risk = group["risk"].iloc[-1]
        risk_label = 1.0 if final_risk > -5 else 0.0

        # Target: log1p of final miss distance
        final_miss = group["miss_distance"].iloc[-1] if "miss_distance" in group.columns else 0.0
        miss_log = np.log1p(max(final_miss, 0.0))

        # Target: log10(Pc) — the Kelvins `risk` column is already log10(Pc).
        # Clamp to [-20, 0] (Pc ranges from ~1e-20 to ~1)
        pc_log10 = float(max(min(final_risk, 0.0), -20.0))

        # Domain weight: Kelvins events get full weight, Space-Track events
        # get reduced weight since they have sparse features (16 vs 103 columns).
        # This prevents the model from learning shortcuts on zero-padded features.
        source = event.get("source", "kelvins")
        domain_weight = 1.0 if source == "kelvins" else 0.3

        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
            "risk_label": torch.tensor(risk_label, dtype=torch.float32),
            "miss_log": torch.tensor(miss_log, dtype=torch.float32),
            "pc_log10": torch.tensor(pc_log10, dtype=torch.float32),
            "domain_weight": torch.tensor(domain_weight, dtype=torch.float32),
        }


class PretrainDataset(_CDMSequenceBase):
    """Simplified CDM dataset for self-supervised pre-training (no labels needed).

    Returns only temporal features, static features, time_to_tca, and mask.
    Can process combined train+test data since labels aren't used.
    """

    def _report_features(self):
        print(f" PretrainDataset — Temporal: {len(self.temporal_cols)}, Static: {len(self.static_cols)}")

    def __getitem__(self, idx):
        group = self.events[idx]["group"]
        temporal, static, tca, mask = self._build_inputs(group)
        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
        }


def build_datasets(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    val_fraction: float = 0.1,
    use_density: bool = False,
    cal_fraction: float = 0.0,
) -> tuple:
    """
    Build train, validation, and test datasets with shared normalization.

    Splits training data into train + val by event_id (stratified by risk).
    When a ``source`` column is present, validation events are drawn from
    "kelvins" events only and all other events go to training.

    Args:
        train_df: Training CDM DataFrame
        test_df: Test CDM DataFrame
        val_fraction: Fraction of Kelvins training events for validation
        use_density: If True, include DENSITY_FEATURES in static features
        cal_fraction: If > 0, further split validation into val + calibration
            for conformal prediction. The split is only performed when the
            validation set has more than 20 events.

    Returns:
        (train_ds, val_ds, cal_ds, test_ds) when a calibration split was made,
        otherwise (train_ds, val_ds, test_ds). Note that the 3-tuple is also
        returned when cal_fraction > 0 but the validation set has 20 or
        fewer events.
    """
    # Compute density features if requested
    if use_density:
        from src.data.density_features import OrbitalDensityComputer

        density_computer = OrbitalDensityComputer()
        density_computer.fit(train_df)
        train_df = density_computer.transform(train_df)
        test_df = density_computer.transform(test_df)
    else:
        density_computer = None

    # Static columns: base (filtered to available) + optional density
    base_cols = [c for c in STATIC_FEATURES if c in train_df.columns]
    static_cols = list(base_cols)
    if use_density:
        static_cols = static_cols + [f for f in DENSITY_FEATURES if f in train_df.columns]

    # Determine risk label per event for stratification
    has_source = "source" in train_df.columns
    agg_dict = {"risk": ("risk", "last")}
    if has_source:
        agg_dict["source"] = ("source", "first")
    event_meta = train_df.groupby("event_id").agg(**agg_dict).reset_index()
    event_meta["label"] = (event_meta["risk"] > -5).astype(int)

    # Split validation from KELVINS-ONLY events for fair model selection.
    # Space-Track events (sparse features, all high-risk) inflate val metrics.
    if has_source:
        is_kelvins = event_meta["source"] == "kelvins"
        kelvins_events = event_meta[is_kelvins]
        other_events = event_meta[~is_kelvins]

        # Stratified split on Kelvins events only
        k_train_ids, val_ids = train_test_split(
            kelvins_events["event_id"].values,
            test_size=val_fraction,
            stratify=kelvins_events["label"].values,
            random_state=42,
        )
        # Training = Kelvins train split + all Space-Track events
        train_ids = np.concatenate([k_train_ids, other_events["event_id"].values])
    else:
        train_ids, val_ids = train_test_split(
            event_meta["event_id"].values,
            test_size=val_fraction,
            stratify=event_meta["label"].values,
            random_state=42,
        )

    # Further split validation into val + calibration for conformal prediction
    cal_ids = np.array([])
    if cal_fraction > 0 and len(val_ids) > 20:
        # Look labels up in val_ids order: val_ids comes back shuffled from
        # train_test_split, so an isin() filter (which keeps event_meta order)
        # would pair ids with the wrong labels for stratification.
        label_by_event = event_meta.set_index("event_id")["label"]
        val_labels = label_by_event.loc[val_ids].values
        val_ids, cal_ids = train_test_split(
            val_ids,
            test_size=cal_fraction,
            stratify=val_labels,
            random_state=123,  # different seed from train/val split
        )

    train_sub = train_df[train_df["event_id"].isin(train_ids)]
    val_sub = train_df[train_df["event_id"].isin(val_ids)]

    print("Building datasets:")
    print(f" Train events: {len(train_ids)}")
    if has_source:
        from_kelvins = train_sub["source"] == "kelvins"
        n_k = train_sub[from_kelvins]["event_id"].nunique()
        n_s = train_sub[~from_kelvins]["event_id"].nunique()
        print(f" (Kelvins: {n_k}, Space-Track: {n_s})")
    if use_density:
        # Count the base columns actually present, not the full candidate list
        print(f" Static features: {len(static_cols)} (base: {len(base_cols)}, "
              f"density: {len(static_cols) - len(base_cols)})")
    train_ds = CDMSequenceDataset(train_sub, static_cols=static_cols)

    print(f" Val events: {len(val_ids)} (Kelvins-only)")
    val_ds = CDMSequenceDataset(val_sub, static_cols=static_cols)
    val_ds.set_normalization(train_ds)  # use training stats

    print(f" Test events: {test_df['event_id'].nunique()}")
    test_ds = CDMSequenceDataset(
        test_df,
        temporal_cols=train_ds.temporal_cols,
        static_cols=static_cols,
    )
    test_ds.set_normalization(train_ds)

    # Store density computer on train_ds for checkpoint saving
    if density_computer is not None:
        train_ds._density_computer = density_computer

    if cal_fraction > 0 and len(cal_ids) > 0:
        cal_sub = train_df[train_df["event_id"].isin(cal_ids)]
        print(f" Cal events: {len(cal_ids)} (for conformal prediction)")
        cal_ds = CDMSequenceDataset(cal_sub, static_cols=static_cols)
        cal_ds.set_normalization(train_ds)
        return train_ds, val_ds, cal_ds, test_ds

    return train_ds, val_ds, test_ds