# Generated by Claude Code -- 2026-02-08
"""Load and parse ESA Kelvins CDM dataset into structured formats."""

import pandas as pd
import numpy as np
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class CDMSnapshot:
    """A single Conjunction Data Message update."""

    time_to_tca: float
    miss_distance: float
    relative_speed: float
    risk: float
    features: np.ndarray  # all numeric feature columns as a flat float32 vector


@dataclass
class ConjunctionEvent:
    """A complete conjunction event = sequence of CDM snapshots."""

    event_id: int
    # build_events orders snapshots by decreasing time_to_tca, so the last
    # entry is the one closest to TCA.
    cdm_sequence: List[CDMSnapshot] = field(default_factory=list)
    # 1 if the FINAL CDM's risk exceeds the threshold passed to build_events.
    risk_label: int = 0
    final_miss_distance: float = 0.0
    altitude_km: float = 0.0
    object_type: str = ""


# Columns excluded from the feature vector (IDs and targets); every other
# numeric column is used as a feature.
EXCLUDE_COLS = {"event_id", "time_to_tca", "risk", "mission_id"}


def get_feature_columns(df: pd.DataFrame) -> list[str]:
    """Get the list of numeric feature columns (excluding IDs and targets)."""
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    return [c for c in numeric_cols if c not in EXCLUDE_COLS]


def load_cdm_csv(path: Path) -> pd.DataFrame:
    """Load a CDM CSV and do basic cleaning.

    NaNs in numeric feature columns (as chosen by ``get_feature_columns``) are
    replaced with 0 because some covariance columns are sparse. Excluded
    columns such as ``time_to_tca`` and ``risk`` are left untouched.
    """
    df = pd.read_csv(path)
    feature_cols = get_feature_columns(df)
    df[feature_cols] = df[feature_cols].fillna(0)
    return df


def load_dataset(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load train and test CDM DataFrames found anywhere under ``data_dir``.

    The search is recursive because the CSVs may sit in a subdirectory after
    archive extraction. When several files match a pattern, the
    lexicographically first path is used, so the choice does not depend on
    filesystem iteration order.

    Returns:
        (train_df, test_df)

    Raises:
        FileNotFoundError: if no ``*train*.csv`` or no ``*test*.csv`` exists.
    """
    train_candidates = sorted(data_dir.rglob("*train*.csv"))
    test_candidates = sorted(data_dir.rglob("*test*.csv"))
    if not train_candidates:
        raise FileNotFoundError(f"No train CSV found in {data_dir}")
    if not test_candidates:
        raise FileNotFoundError(f"No test CSV found in {data_dir}")

    train_path = train_candidates[0]
    test_path = test_candidates[0]
    print(f"Loading train: {train_path}")
    print(f"Loading test: {test_path}")

    train_df = load_cdm_csv(train_path)
    test_df = load_cdm_csv(test_path)
    print(f"Train: {len(train_df)} rows, {train_df['event_id'].nunique()} events")
    print(f"Test: {len(test_df)} rows, {test_df['event_id'].nunique()} events")
    return train_df, test_df


def _column_or_zeros(df: pd.DataFrame, name: str) -> np.ndarray:
    """Return column ``name`` as a float64 array in row order, or zeros if absent."""
    if name in df.columns:
        return df[name].to_numpy(dtype=np.float64)
    return np.zeros(len(df), dtype=np.float64)


def build_events(
    df: pd.DataFrame,
    feature_cols: Optional[list[str]] = None,
    *,
    risk_threshold: float = -5.0,
) -> list[ConjunctionEvent]:
    """Group CDM rows by event_id into ConjunctionEvent objects.

    Within each event, snapshots are ordered by decreasing ``time_to_tca``,
    so ``cdm_sequence[-1]`` is the CDM closest to TCA.

    Args:
        df: CDM DataFrame. It is not modified.
        feature_cols: optional fixed list of feature columns (for train/test
            consistency). Columns missing from ``df`` are filled with 0.
        risk_threshold: an event gets ``risk_label == 1`` when its final CDM's
            ``risk`` is strictly greater than this value.

    Returns:
        One ConjunctionEvent per distinct event_id, in ascending event_id order.
    """
    if feature_cols is None:
        feature_cols = get_feature_columns(df)
    else:
        missing = [c for c in feature_cols if c not in df.columns]
        if missing:
            # assign() returns a copy, so the caller's frame is untouched.
            df = df.assign(**{c: 0.0 for c in missing})

    # Every per-row array below is indexed by the row's position in the
    # ORIGINAL (unsorted) frame. Sorting only determines the visiting order.
    # (Previously the scalars were read with df.iloc on the sorted frame
    # using original positions, which mixed up rows.)
    feature_matrix = np.nan_to_num(
        df[feature_cols].to_numpy(dtype=np.float64),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    tca = _column_or_zeros(df, "time_to_tca")
    miss = _column_or_zeros(df, "miss_distance")
    speed = _column_or_zeros(df, "relative_speed")
    risk = _column_or_zeros(df, "risk")

    ordered = df.assign(_row_idx=np.arange(len(df)))
    if "time_to_tca" in df.columns:
        ordered = ordered.sort_values(["event_id", "time_to_tca"], ascending=[True, False])
    else:
        ordered = ordered.sort_values("event_id", kind="stable")

    # Altitude source: first of these columns present in the data.
    alt_col = next((c for c in ("t_h_apo", "c_h_apo") if c in df.columns), None)
    has_obj_type = "c_object_type" in df.columns

    events = []
    for event_id, group in ordered.groupby("event_id", sort=True):
        cdm_seq = [
            CDMSnapshot(
                time_to_tca=float(tca[i]),
                miss_distance=float(miss[i]),
                relative_speed=float(speed[i]),
                risk=float(risk[i]),
                features=feature_matrix[i].astype(np.float32),
            )
            for i in group["_row_idx"].to_numpy()
        ]
        final_cdm = cdm_seq[-1]
        risk_label = 1 if final_cdm.risk > risk_threshold else 0
        # group is in the same order as cdm_seq, so iloc[-1] is the final CDM.
        alt = float(group[alt_col].iloc[-1]) if alt_col else 0.0
        obj_type = str(group["c_object_type"].iloc[0]) if has_obj_type else "unknown"

        events.append(ConjunctionEvent(
            event_id=int(event_id),
            cdm_sequence=cdm_seq,
            risk_label=risk_label,
            final_miss_distance=final_cdm.miss_distance,
            altitude_km=alt,
            object_type=obj_type,
        ))

    n_high = sum(e.risk_label for e in events)
    pct_high = 100 * n_high / len(events) if events else 0.0
    print(f"Built {len(events)} events, {n_high} high-risk ({pct_high:.1f}%)")
    return events


def events_to_flat_features(events: list[ConjunctionEvent]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract flat feature vectors from events for classical ML.

    Each row is the LAST CDM snapshot's feature vector (closest to TCA)
    followed by 8 temporal summary features:
        [n_cdms, miss_mean, miss_std, miss_trend, risk_trend,
         first_miss - last_miss, last time_to_tca, last relative_speed]
    Trends are least-squares slopes against time_to_tca. They are 0 when the
    event has a single CDM or constant time_to_tca.

    Returns:
        (X, y_risk, y_miss) where ``y_miss`` is ``log1p`` of the final miss
        distance clipped at 0.

    Raises:
        ValueError: if ``events`` is empty.
    """
    if not events:
        raise ValueError("events must be non-empty")

    X_list = []
    y_risk = []
    y_miss = []

    for event in events:
        seq = event.cdm_sequence
        last = seq[-1]
        base = last.features.copy()

        miss_values = np.array([s.miss_distance for s in seq])
        risk_values = np.array([s.risk for s in seq])
        tca_values = np.array([s.time_to_tca for s in seq])
        n_cdms = len(seq)

        miss_mean = float(np.mean(miss_values))
        miss_std = float(np.std(miss_values)) if n_cdms > 1 else 0.0

        # polyfit needs at least two distinct x values to fit a slope.
        has_trend = n_cdms > 1 and np.std(tca_values) > 0
        miss_trend = float(np.polyfit(tca_values, miss_values, 1)[0]) if has_trend else 0.0
        risk_trend = float(np.polyfit(tca_values, risk_values, 1)[0]) if has_trend else 0.0

        temporal_feats = np.array([
            n_cdms,
            miss_mean,
            miss_std,
            miss_trend,
            risk_trend,
            float(miss_values[0] - miss_values[-1]) if n_cdms > 1 else 0.0,
            last.time_to_tca,
            last.relative_speed,
        ], dtype=np.float32)

        X_list.append(np.concatenate([base, temporal_feats]))
        y_risk.append(event.risk_label)
        y_miss.append(np.log1p(max(event.final_miss_distance, 0.0)))

    X = np.nan_to_num(np.stack(X_list), nan=0.0, posinf=0.0, neginf=0.0)
    return X, np.array(y_risk), np.array(y_miss)