Spaces:
Sleeping
Sleeping
| # Generated by Claude Code -- 2026-02-08 | |
| """Build padded CDM sequences for the Temporal Fusion Transformer. | |
| Each conjunction event is a variable-length time series of CDM snapshots. | |
| This module handles: | |
| - Selecting temporal vs static features | |
| - Padding/truncating to fixed length | |
| - Creating attention masks for padded positions | |
| - Train/val/test splitting with stratification | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import Dataset | |
| from sklearn.model_selection import train_test_split | |
| from pathlib import Path | |
# Maximum CDM sequence length (95th percentile of real data is ~25)
MAX_SEQ_LEN = 30

# Features that change with each CDM update (time-varying).
# The r/t/n suffixes presumably denote a radial/transverse/normal frame —
# TODO confirm against the CDM data dictionary.
TEMPORAL_FEATURES = [
    "miss_distance",
    "relative_speed",
    "relative_position_r", "relative_position_t", "relative_position_n",
    "relative_velocity_r", "relative_velocity_t", "relative_velocity_n",
    "max_risk_estimate", "max_risk_scaling",
    # Target object covariance
    "t_sigma_r", "t_sigma_t", "t_sigma_n",
    "t_sigma_rdot", "t_sigma_tdot", "t_sigma_ndot",
    # Chaser object covariance
    "c_sigma_r", "c_sigma_t", "c_sigma_n",
    "c_sigma_rdot", "c_sigma_tdot", "c_sigma_ndot",
]

# Features that are constant per event (object properties);
# t_ = target object, c_ = chaser object.
STATIC_FEATURES = [
    "t_h_apo", "t_h_per", "t_j2k_sma", "t_j2k_inc", "t_ecc",
    "c_h_apo", "c_h_per", "c_j2k_sma", "c_j2k_inc", "c_ecc",
    "t_span", "c_span",
]

# Orbital density features from CRASH Clock analysis (added by OrbitalDensityComputer).
# Only appended to the static set when build_datasets(use_density=True).
DENSITY_FEATURES = [
    "shell_density",
    "shell_collision_rate",
    "local_crash_clock_log",
    "altitude_percentile",
    "n_events_in_shell",
    "shell_risk_rate",
]
def find_available_features(df: pd.DataFrame, candidates: list[str]) -> list[str]:
    """Return the subset of *candidates* present as columns in *df*.

    Order of *candidates* is preserved. Prints a short note when some
    candidates are absent so silently dropped features stay visible.
    """
    present = set(df.columns)
    available: list[str] = []
    missing: list[str] = []
    for name in candidates:
        (available if name in present else missing).append(name)
    if missing:
        print(f" Note: {len(missing)} features not in dataset, using {len(available)}")
    return available
class CDMSequenceDataset(Dataset):
    """
    PyTorch Dataset that serves padded CDM sequences for the Transformer.
    Each item contains:
    - temporal: (S, 2 * F_t) tensor — normalized time-varying CDM features
      concatenated with their normalized first-order differences (deltas)
    - static: (F_s,) tensor of object properties
    - time_to_tca: (S, 1) tensor of time-to-closest-approach values
    - mask: (S,) boolean mask (True = real data, False = padding)
    - risk_label: scalar binary target (final risk > -5, i.e. Pc > 1e-5)
    - miss_log: scalar log1p(final_miss_distance) target
    - pc_log10: scalar log10(Pc) target, clamped to [-20, 0]
    - domain_weight: scalar per-sample loss weight by data source
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] | None = None,
        static_cols: list[str] | None = None,
    ):
        """
        Args:
            df: CDM DataFrame, one row per CDM message, keyed by "event_id".
            max_seq_len: Fixed sequence length after left-padding/truncation.
            temporal_cols: Time-varying columns; None selects the available
                subset of TEMPORAL_FEATURES.
            static_cols: Per-event columns; None selects the available
                subset of STATIC_FEATURES.
        """
        self.max_seq_len = max_seq_len
        # Use `is None` rather than truthiness so a caller-supplied empty
        # list is honored instead of silently triggering the defaults.
        if temporal_cols is None:
            temporal_cols = find_available_features(df, TEMPORAL_FEATURES)
        if static_cols is None:
            static_cols = find_available_features(df, STATIC_FEATURES)
        self.temporal_cols = temporal_cols
        self.static_cols = static_cols
        print(f" Temporal features: {len(self.temporal_cols)}")
        print(f" Static features: {len(self.static_cols)}")
        # Group by event_id: one dataset item per conjunction event
        self.events = []
        for event_id, group in df.groupby("event_id"):
            # Sort by time_to_tca descending (first CDM = furthest from TCA)
            group = group.sort_values("time_to_tca", ascending=False)
            # Track data source for domain weighting
            source = "kelvins"
            if "source" in group.columns:
                source = group["source"].iloc[0]
            self.events.append({
                "event_id": event_id,
                "group": group,
                "source": source,
            })
        # Global normalization stats from this df; val/test sets should copy
        # training stats via set_normalization(). nan_to_num guards all-NaN
        # columns; `~(std > eps)` also catches NaN stds (NaN comparisons are
        # False), which the original `std < eps` check missed.
        self.temporal_mean = np.nan_to_num(df[self.temporal_cols].mean().values.astype(np.float32))
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[~(self.temporal_std > 1e-8)] = 1.0  # avoid div by zero / NaN
        self.static_mean = np.nan_to_num(df[self.static_cols].mean().values.astype(np.float32))
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[~(self.static_std > 1e-8)] = 1.0
        # Normalize time_to_tca
        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if not (self.tca_std > 1e-8):  # also catches NaN (e.g. single-row df)
            self.tca_std = 1.0
        # Compute delta normalization stats (approx from per-step differences).
        # Deltas have different magnitude than raw features, need separate stats.
        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame):
        """Estimate normalization stats for temporal first-order differences.

        Samples deltas from up to 2000 multi-CDM events for speed; falls back
        to mean 0 / std 1 when no event has two or more CDMs.
        """
        delta_samples = []
        for _, group in df.groupby("event_id"):
            if len(group) < 2:
                continue
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[~(self.delta_std > 1e-8)] = 1.0  # NaN-safe guard
        else:
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other: "CDMSequenceDataset"):
        """Copy normalization stats from another dataset (e.g., training set)."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self):
        """Number of conjunction events (not CDM rows)."""
        return len(self.events)

    def __getitem__(self, idx):
        """Return one padded event as a dict of tensors (see class docstring)."""
        event = self.events[idx]
        group = event["group"]
        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)
        # First-order differences (deltas) capture trends: is miss_distance
        # shrinking? Is covariance tightening?
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)  # (seq_len-1, n_temporal)
            # Prepend zeros for the first timestep (no prior to diff against)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)
        # Normalize raw features and deltas separately
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std
        # Concatenate: (seq_len, n_temporal * 2)
        temporal = np.concatenate([temporal, deltas], axis=1)
        # Extract static features from last row (they're constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
        # Time-to-TCA values: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)
        # Normalize
        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std
        # Truncate or pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            # Keep the most recent CDMs (closest to TCA = most informative)
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len
        # Pad (left-pad so the most recent CDM is always at position -1)
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)
        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True
        # Target: risk label from final CDM's risk column.
        # risk > -5 means collision probability > 1e-5 (high risk).
        final_risk = group["risk"].iloc[-1]
        if not np.isfinite(final_risk):
            # A NaN risk would flow unfiltered through min()/max() below and
            # emit a NaN pc_log10 target that poisons the loss; treat missing
            # risk as the minimum representable probability instead. The
            # binary label is unchanged (NaN > -5 was already False).
            final_risk = -20.0
        risk_label = 1.0 if final_risk > -5 else 0.0
        # Target: log1p of final miss distance (NaN-safe for the same reason)
        final_miss = group["miss_distance"].iloc[-1] if "miss_distance" in group.columns else 0.0
        if not np.isfinite(final_miss):
            final_miss = 0.0
        miss_log = np.log1p(max(final_miss, 0.0))
        # Target: log10(Pc) — the Kelvins `risk` column is already log10(Pc).
        # Clamp to [-20, 0] (Pc ranges from ~1e-20 to ~1)
        pc_log10 = float(max(min(final_risk, 0.0), -20.0))
        # Domain weight: Kelvins events get full weight, Space-Track events
        # get reduced weight since they have sparse features (16 vs 103 columns).
        # This prevents the model from learning shortcuts on zero-padded features.
        source = event.get("source", "kelvins")
        domain_weight = 1.0 if source == "kelvins" else 0.3
        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
            "risk_label": torch.tensor(risk_label, dtype=torch.float32),
            "miss_log": torch.tensor(miss_log, dtype=torch.float32),
            "pc_log10": torch.tensor(pc_log10, dtype=torch.float32),
            "domain_weight": torch.tensor(domain_weight, dtype=torch.float32),
        }
class PretrainDataset(Dataset):
    """Simplified CDM dataset for self-supervised pre-training (no labels needed).

    Returns only temporal features (raw + deltas), static features,
    time_to_tca, and the padding mask. Can process combined train+test data
    since labels aren't used.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        max_seq_len: int = MAX_SEQ_LEN,
        temporal_cols: list[str] | None = None,
        static_cols: list[str] | None = None,
    ):
        """
        Args:
            df: CDM DataFrame, one row per CDM message, keyed by "event_id".
            max_seq_len: Fixed sequence length after left-padding/truncation.
            temporal_cols: Time-varying columns; None selects the available
                subset of TEMPORAL_FEATURES.
            static_cols: Per-event columns; None selects the available
                subset of STATIC_FEATURES.
        """
        self.max_seq_len = max_seq_len
        # `is None` (not truthiness) so an explicit empty list is honored.
        if temporal_cols is None:
            temporal_cols = find_available_features(df, TEMPORAL_FEATURES)
        if static_cols is None:
            static_cols = find_available_features(df, STATIC_FEATURES)
        self.temporal_cols = temporal_cols
        self.static_cols = static_cols
        print(f" PretrainDataset — Temporal: {len(self.temporal_cols)}, Static: {len(self.static_cols)}")
        # Group by event_id; sort each event so the first CDM is furthest
        # from TCA (time_to_tca descending).
        self.events = []
        for event_id, group in df.groupby("event_id"):
            group = group.sort_values("time_to_tca", ascending=False)
            self.events.append({"event_id": event_id, "group": group})
        # Global normalization stats. NaN-safe guards mirror
        # CDMSequenceDataset: `~(std > eps)` also catches NaN stds, which a
        # plain `std < eps` comparison misses.
        self.temporal_mean = np.nan_to_num(df[self.temporal_cols].mean().values.astype(np.float32))
        self.temporal_std = df[self.temporal_cols].std().values.astype(np.float32)
        self.temporal_std[~(self.temporal_std > 1e-8)] = 1.0
        self.static_mean = np.nan_to_num(df[self.static_cols].mean().values.astype(np.float32))
        self.static_std = df[self.static_cols].std().values.astype(np.float32)
        self.static_std[~(self.static_std > 1e-8)] = 1.0
        self.tca_mean = float(df["time_to_tca"].mean())
        self.tca_std = float(df["time_to_tca"].std())
        if not (self.tca_std > 1e-8):  # also catches NaN
            self.tca_std = 1.0
        self._compute_delta_stats(df)

    def _compute_delta_stats(self, df: pd.DataFrame):
        """Estimate normalization stats for temporal first-order differences.

        Samples deltas from up to 2000 multi-CDM events for speed; falls back
        to mean 0 / std 1 when no event has two or more CDMs.
        """
        delta_samples = []
        for _, group in df.groupby("event_id"):
            if len(group) < 2:
                continue
            vals = group[self.temporal_cols].values.astype(np.float32)
            vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
            deltas = np.diff(vals, axis=0)
            delta_samples.append(deltas)
            if len(delta_samples) >= 2000:  # cap for speed
                break
        if delta_samples:
            all_deltas = np.concatenate(delta_samples, axis=0)
            self.delta_mean = all_deltas.mean(axis=0).astype(np.float32)
            self.delta_std = all_deltas.std(axis=0).astype(np.float32)
            self.delta_std[~(self.delta_std > 1e-8)] = 1.0  # NaN-safe guard
        else:
            n = len(self.temporal_cols)
            self.delta_mean = np.zeros(n, dtype=np.float32)
            self.delta_std = np.ones(n, dtype=np.float32)

    def set_normalization(self, other):
        """Copy normalization stats from another dataset (e.g., training set)."""
        self.temporal_mean = other.temporal_mean
        self.temporal_std = other.temporal_std
        self.static_mean = other.static_mean
        self.static_std = other.static_std
        self.tca_mean = other.tca_mean
        self.tca_std = other.tca_std
        self.delta_mean = other.delta_mean
        self.delta_std = other.delta_std

    def __len__(self):
        """Number of conjunction events (not CDM rows)."""
        return len(self.events)

    def __getitem__(self, idx):
        """Return one padded event: temporal, static, time_to_tca, mask."""
        event = self.events[idx]
        group = event["group"]
        # Extract temporal features: (seq_len, n_temporal)
        temporal = group[self.temporal_cols].values.astype(np.float32)
        temporal = np.nan_to_num(temporal, nan=0.0, posinf=0.0, neginf=0.0)
        # First-order differences; zero row prepended for the first timestep
        if len(temporal) > 1:
            deltas = np.diff(temporal, axis=0)
            deltas = np.concatenate([np.zeros((1, deltas.shape[1]), dtype=np.float32), deltas], axis=0)
        else:
            deltas = np.zeros_like(temporal)
        # Normalize raw features and deltas separately, then concatenate
        # into (seq_len, n_temporal * 2)
        temporal = (temporal - self.temporal_mean) / self.temporal_std
        deltas = (deltas - self.delta_mean) / self.delta_std
        temporal = np.concatenate([temporal, deltas], axis=1)
        # Static features from the last row (constant per event)
        static = group[self.static_cols].iloc[-1].values.astype(np.float32)
        static = np.nan_to_num(static, nan=0.0, posinf=0.0, neginf=0.0)
        # Time-to-TCA: (seq_len, 1)
        tca = group["time_to_tca"].values.astype(np.float32).reshape(-1, 1)
        static = (static - self.static_mean) / self.static_std
        tca = (tca - self.tca_mean) / self.tca_std
        # Truncate (keep most recent CDMs) or left-pad to max_seq_len
        seq_len = len(temporal)
        if seq_len > self.max_seq_len:
            temporal = temporal[-self.max_seq_len:]
            tca = tca[-self.max_seq_len:]
            seq_len = self.max_seq_len
        pad_len = self.max_seq_len - seq_len
        if pad_len > 0:
            temporal = np.pad(temporal, ((pad_len, 0), (0, 0)), constant_values=0)
            tca = np.pad(tca, ((pad_len, 0), (0, 0)), constant_values=0)
        # Attention mask: True for real positions, False for padding
        mask = np.zeros(self.max_seq_len, dtype=bool)
        mask[pad_len:] = True
        return {
            "temporal": torch.tensor(temporal, dtype=torch.float32),
            "static": torch.tensor(static, dtype=torch.float32),
            "time_to_tca": torch.tensor(tca, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.bool),
        }
def build_datasets(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    val_fraction: float = 0.1,
    use_density: bool = False,
    cal_fraction: float = 0.0,
) -> tuple:
    """
    Build train, validation, and test datasets with shared normalization.
    Splits training data into train + val by event_id (stratified by risk).
    Args:
        train_df: Training CDM DataFrame
        test_df: Test CDM DataFrame
        val_fraction: Fraction of Kelvins training events for validation
        use_density: If True, include DENSITY_FEATURES in static features
        cal_fraction: If > 0, further split validation into val + calibration
            for conformal prediction. Returns 4-tuple instead of 3.
    Returns:
        If cal_fraction == 0: (train_ds, val_ds, test_ds)
        If cal_fraction > 0: (train_ds, val_ds, cal_ds, test_ds)
    """
    # Compute density features if requested
    if use_density:
        from src.data.density_features import OrbitalDensityComputer
        density_computer = OrbitalDensityComputer()
        density_computer.fit(train_df)
        train_df = density_computer.transform(train_df)
        test_df = density_computer.transform(test_df)
    else:
        density_computer = None
    # Static columns: base (filtered to available) + optional density
    static_cols = [c for c in STATIC_FEATURES if c in train_df.columns]
    if use_density:
        static_cols = static_cols + [
            f for f in DENSITY_FEATURES if f in train_df.columns
        ]
    # Determine risk label per event for stratification (last CDM's risk)
    has_source = "source" in train_df.columns
    agg_dict = {"risk": ("risk", "last")}
    if has_source:
        agg_dict["source"] = ("source", "first")
    event_meta = train_df.groupby("event_id").agg(**agg_dict).reset_index()
    event_meta["label"] = (event_meta["risk"] > -5).astype(int)
    # Split validation from KELVINS-ONLY events for fair model selection.
    # Space-Track events (sparse features, all high-risk) inflate val metrics.
    if has_source:
        kelvins_events = event_meta[event_meta["source"] == "kelvins"]
        other_events = event_meta[event_meta["source"] != "kelvins"]
        kelvins_ids = kelvins_events["event_id"].values
        kelvins_labels = kelvins_events["label"].values
        # Stratified split on Kelvins events only
        k_train_ids, val_ids = train_test_split(
            kelvins_ids, test_size=val_fraction, stratify=kelvins_labels, random_state=42
        )
        # Training = Kelvins train split + all Space-Track events
        train_ids = np.concatenate([k_train_ids, other_events["event_id"].values])
    else:
        event_ids = event_meta["event_id"].values
        labels = event_meta["label"].values
        train_ids, val_ids = train_test_split(
            event_ids, test_size=val_fraction, stratify=labels, random_state=42
        )
    # Further split validation into val + calibration for conformal prediction
    cal_ids = np.array([])
    if cal_fraction > 0 and len(val_ids) > 20:
        # BUGFIX: `stratify` labels must be aligned element-wise with the
        # array being split. The previous boolean-isin lookup returned labels
        # in event_meta order, not val_ids order, silently pairing ids with
        # the wrong labels. Use an index-aligned lookup instead.
        label_by_event = event_meta.set_index("event_id")["label"]
        val_labels = label_by_event.loc[val_ids].values
        val_ids, cal_ids = train_test_split(
            val_ids,
            test_size=cal_fraction,
            stratify=val_labels,
            random_state=123,  # different seed from train/val split
        )
    train_sub = train_df[train_df["event_id"].isin(train_ids)]
    val_sub = train_df[train_df["event_id"].isin(val_ids)]
    print(f"Building datasets:")
    print(f" Train events: {len(train_ids)}")
    if has_source:
        n_k = train_sub[train_sub["source"] == "kelvins"]["event_id"].nunique()
        n_s = train_sub[train_sub["source"] != "kelvins"]["event_id"].nunique()
        print(f" (Kelvins: {n_k}, Space-Track: {n_s})")
    if use_density:
        print(f" Static features: {len(static_cols)} (base: {len(STATIC_FEATURES)}, "
              f"density: {len(static_cols) - len(STATIC_FEATURES)})")
    train_ds = CDMSequenceDataset(train_sub, static_cols=static_cols)
    print(f" Val events: {len(val_ids)} (Kelvins-only)")
    val_ds = CDMSequenceDataset(val_sub, static_cols=static_cols)
    val_ds.set_normalization(train_ds)  # use training stats
    print(f" Test events: {test_df['event_id'].nunique()}")
    # Pin the test set to the training feature columns so a sparse test
    # DataFrame can't silently select a different feature subset.
    test_ds = CDMSequenceDataset(test_df, temporal_cols=train_ds.temporal_cols, static_cols=static_cols)
    test_ds.set_normalization(train_ds)
    # Store density computer on train_ds for checkpoint saving
    if density_computer is not None:
        train_ds._density_computer = density_computer
    if cal_fraction > 0 and len(cal_ids) > 0:
        cal_sub = train_df[train_df["event_id"].isin(cal_ids)]
        print(f" Cal events: {len(cal_ids)} (for conformal prediction)")
        cal_ds = CDMSequenceDataset(cal_sub, static_cols=static_cols)
        cal_ds.set_normalization(train_ds)
        return train_ds, val_ds, cal_ds, test_ds
    return train_ds, val_ds, test_ds