Spaces:
Sleeping
Sleeping
| # Generated by Claude Code -- 2026-02-08 | |
| """Model 3: Physics-Informed Temporal Fusion Transformer (PI-TFT). | |
| Architecture overview (think of it like reading serial lab values): | |
| 1. VARIABLE SELECTION: Not all 22 CDM features matter equally. The model | |
| learns attention weights over features -- e.g., miss_distance and | |
| covariance shrinkage rate might matter more than raw orbital elements. | |
| This is like a doctor learning which labs to focus on. | |
| 2. STATIC CONTEXT: Object properties (altitude, size, eccentricity) don't | |
| change between CDM updates. They're encoded once and injected as context | |
| into the temporal processing. Like knowing the patient's age and history. | |
| 3. CONTINUOUS TIME EMBEDDING: CDMs arrive at irregular intervals (not evenly | |
| spaced). Instead of positional encoding (position 1, 2, 3...), we embed | |
| the actual time_to_tca value. The model knows "this CDM was 3.2 days | |
| before closest approach" vs "this one was 0.5 days before." | |
| 4. TEMPORAL SELF-ATTENTION: The Transformer reads the full CDM sequence and | |
| learns which updates were most informative. A sudden miss distance drop | |
| at day -2 gets more attention than a stable reading at day -5. | |
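| 5. PREDICTION HEADS: Attention-weighted pooling over all real (non-padded) | |
| CDM time steps produces one summary vector, which feeds into three | |
| prediction heads: | |
| - Risk classifier: logit for high-risk collision (sigmoid gives the probability) | |
| - Miss distance regressor: predicted log1p(miss distance in km) | |
| - Collision probability regressor: predicted log10(Pc) | |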
| 6. PHYSICS LOSS: The training loss includes a penalty when the model predicts | |
| a miss distance BELOW the Minimum Orbital Intersection Distance (MOID). | |
| MOID is the closest the two orbits can geometrically get. Predicting | |
| closer than MOID is physically impossible (without a maneuver), so we | |
| penalize it. This is like penalizing a model for predicting negative | |
| blood pressure -- constraining outputs to the physically possible range. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
class GatedResidualNetwork(nn.Module):
    """Gated residual block: an ELU MLP whose output is blended into the input.

    Computes::

        h   = dropout(ELU(fc1(x)))
        out = LayerNorm(x + sigmoid(gate_fc(h)) * fc2(h))

    The sigmoid gate acts element-wise. When it saturates near 0 the block
    reduces to ``LayerNorm(x)``. Near 1 it adds the full transform ``fc2(h)``
    to the residual before normalising. Input and output share the last
    dimension ``d_model``; any leading dimensions pass through untouched.

    Args:
        d_model: Size of the last dimension of both input and output.
        d_hidden: Hidden width. Falls back to ``d_model`` when falsy
            (``None`` or 0).
        dropout: Dropout probability applied to the hidden activation.
    """

    def __init__(self, d_model: int, d_hidden: int = None, dropout: float = 0.1):
        super().__init__()
        hidden_width = d_hidden if d_hidden else d_model
        # Submodules are created in a fixed order so that seeded initialisation
        # and state_dict keys stay the same across versions of this class.
        self.fc1 = nn.Linear(d_model, hidden_width)
        self.fc2 = nn.Linear(hidden_width, d_model)
        self.gate_fc = nn.Linear(hidden_width, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.dropout(F.elu(self.fc1(x)))
        # Gate and candidate update are both computed from the same hidden state.
        gated_update = torch.sigmoid(self.gate_fc(hidden)) * self.fc2(hidden)
        return self.norm(x + gated_update)
class VariableSelectionNetwork(nn.Module):
    """Softmax-weighted selection over scalar input features.

    Each scalar feature has its own ``nn.Linear(1, d_model)`` embedding.
    The embeddings are concatenated and mapped by one linear layer to one
    score per feature. A softmax then turns the scores into weights that sum
    to 1 along the last axis. The output is the weighted sum of the
    embeddings, refined by a GatedResidualNetwork. The weights are returned
    so callers can inspect which inputs the model relied on.

    Any leading shape is accepted, e.g. ``(B, F)`` for static inputs or
    ``(B, S, F)`` for sequences.

    Args:
        n_features: Number of scalar features in the last input dimension.
        d_model: Embedding size of each feature and of the output.
        dropout: Dropout probability forwarded to the output GRN.
    """

    def __init__(self, n_features: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        # One scalar -> d_model embedding per input column.
        self.feature_projections = nn.ModuleList(
            nn.Linear(1, d_model) for _ in range(n_features)
        )
        # Concatenated embeddings -> per-feature selection weights (softmax).
        self.gate_network = nn.Sequential(
            nn.Linear(n_features * d_model, n_features),
            nn.Softmax(dim=-1),
        )
        self.grn = GatedResidualNetwork(d_model, dropout=dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed, weight and fuse the features of ``x``.

        Args:
            x: (..., n_features) — e.g. (B, F) for static or (B, S, F) for temporal.

        Returns:
            output: (..., d_model) — GRN-refined weighted sum of feature embeddings.
            weights: (..., n_features) — softmax selection weights (sum to 1).
        """
        # (..., n_features, d_model): slot k holds column k embedded by projection k.
        embedded = torch.stack(
            [
                projection(x[..., col : col + 1])
                for col, projection in enumerate(self.feature_projections)
            ],
            dim=-2,
        )
        # Merge the last two axes so the gate sees every embedding at once.
        weights = self.gate_network(embedded.flatten(start_dim=-2))
        fused = (weights.unsqueeze(-1) * embedded).sum(dim=-2)
        return self.grn(fused), weights
class PhysicsInformedTFT(nn.Module):
    """
    Physics-Informed Temporal Fusion Transformer for conjunction assessment.

    Input flow:
        temporal_features (B, S, F_t) → Variable Selection → + time embedding
            + static enrichment → GRN → Transformer encoder → GRN
            → attention pool → heads
        static_features (B, F_s) → Variable Selection → static encoder
            → enrichment vector added to every time step ↗

    ``forward`` returns a 4-tuple:
        risk_logit: (B, 1) — raw logit for risk classification (apply sigmoid for probability)
        miss_log: (B, 1) — predicted log1p(miss_distance_km)
        pc_log10: (B, 1) — predicted log10(Pc) collision probability (always computed)
        feature_weights: (B, S, F_t) — which temporal features mattered at each step

    Args:
        n_temporal_features: Number of per-CDM features (F_t).
        n_static_features: Number of per-event static features (F_s).
        d_model: Hidden size D; must be divisible by ``n_heads``.
        n_heads: Attention heads per Transformer layer.
        n_layers: Number of Transformer encoder layers.
        dropout: Dropout probability used throughout the model.
        max_seq_len: Stored as ``self.max_seq_len``; not used inside this class.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        max_seq_len: int = 30,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # NOTE: submodule names and creation order are part of the checkpoint
        # format (state_dict keys) and of seeded initialisation; keep them stable.

        # --- Variable Selection Networks ---
        self.temporal_vsn = VariableSelectionNetwork(n_temporal_features, d_model, dropout)
        self.static_vsn = VariableSelectionNetwork(n_static_features, d_model, dropout)

        # --- Static context encoding ---
        self.static_encoder = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        # Static -> enrichment vector that's added to each temporal step
        self.static_to_enrichment = nn.Linear(d_model, d_model)

        # --- Continuous time embedding ---
        # Instead of fixed positional encoding, we embed the actual time_to_tca
        self.time_embedding = nn.Sequential(
            nn.Linear(1, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model),
        )

        # --- Transformer encoder layers ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 2,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers
        )

        # --- Pre/post attention processing ---
        self.pre_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)
        self.post_attn_grn = GatedResidualNetwork(d_model, dropout=dropout)

        # --- Attention-weighted pooling ---
        # Learns which time steps matter most instead of just taking the last one.
        # Softmax attention over all real positions, with padding masked out.
        self.pool_attention = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Tanh(),
            nn.Linear(d_model // 2, 1),
        )

        # --- Prediction heads ---
        self.risk_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )
        self.miss_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

        # --- Collision probability head ---
        # Predicts log10(Pc) directly instead of binary risk classification.
        # Pc ranges from ~1e-20 to ~1e-1, so log10 scale maps to [-20, -1].
        # The Kelvins `risk` column is already log10(Pc).
        self.pc_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def encode_sequence(
        self,
        temporal_features: torch.Tensor,  # (B, S, F_t)
        static_features: torch.Tensor,  # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1) or (B, S)
        mask: torch.Tensor,  # (B, S) — True for real, False for padding
    ):
        """Encode a CDM sequence into per-timestep hidden states.

        Args:
            temporal_features: (B, S, F_t) per-CDM features.
            static_features: (B, F_s) per-event features.
            time_to_tca: (B, S, 1) time-to-closest-approach of each CDM. A
                (B, S) tensor is also accepted and gets a trailing axis added.
            mask: (B, S) True for real CDMs, False for padding. Non-bool masks
                are converted with ``.bool()`` (nonzero = real).

        Returns:
            hidden: (B, S, D) per-timestep representations after the Transformer.
            temporal_weights: (B, S, F_t) variable-selection attention weights.

        A row whose mask is entirely False still has its first position left
        attendable inside the Transformer. An attention row with every key
        masked can yield NaN (behaviour varies across PyTorch versions and
        backends), and that NaN would flow into the batch loss. Such rows get
        finite but meaningless hidden states; ``forward`` pools them to zero.
        """
        mask = mask.bool()
        if time_to_tca.dim() == temporal_features.dim() - 1:
            # (B, S) -> (B, S, 1); otherwise Linear(1, ...) would misread the axes.
            time_to_tca = time_to_tca.unsqueeze(-1)

        # 1. Variable selection -- learn which features matter
        temporal_selected, temporal_weights = self.temporal_vsn(temporal_features)
        # temporal_selected: (B, S, D), temporal_weights: (B, S, F_t)
        static_selected, _ = self.static_vsn(static_features)  # (B, D)

        # 2. Static context -- compute enrichment vector
        static_ctx = self.static_encoder(static_selected)  # (B, D)
        enrichment = self.static_to_enrichment(static_ctx)  # (B, D)

        # 3. Continuous time embedding
        t_embed = self.time_embedding(time_to_tca)  # (B, S, D)

        # 4. Combine: temporal + time + static context (broadcast over S)
        x = temporal_selected + t_embed + enrichment.unsqueeze(1)

        # 5. Pre-attention GRN
        x = self.pre_attn_grn(x)

        # 6. Transformer self-attention.
        # PyTorch's src_key_padding_mask uses True = ignore, the inverse of `mask`.
        # Rows with no real position keep position 0 attendable (see docstring).
        attendable = mask.clone()
        attendable[:, 0] |= ~mask.any(dim=-1)
        x = self.transformer_encoder(x, src_key_padding_mask=~attendable)

        # 7. Post-attention GRN
        x = self.post_attn_grn(x)
        return x, temporal_weights

    def forward(
        self,
        temporal_features: torch.Tensor,  # (B, S, F_t)
        static_features: torch.Tensor,  # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1) or (B, S)
        mask: torch.Tensor,  # (B, S) — True for real, False for padding
    ):
        """Run the full model.

        Arguments are as for ``encode_sequence``.

        Returns:
            risk_logit: (B, 1), miss_log: (B, 1), pc_log10: (B, 1),
            temporal_weights: (B, S, F_t).

        Raises:
            ValueError: If ``temporal_features`` is not 3-D.
        """
        if temporal_features.dim() != 3:
            raise ValueError(
                f"temporal_features must be (B, S, F_t), got shape {tuple(temporal_features.shape)}"
            )
        mask = mask.bool()

        # Steps 1-7: encode sequence into per-timestep hidden states
        x, temporal_weights = self.encode_sequence(
            temporal_features, static_features, time_to_tca, mask
        )

        # 8. Attention-weighted pooling over all real positions
        x_pooled = self._attention_pool(x, mask)  # (B, D)

        # 9. Prediction heads
        risk_logit = self.risk_head(x_pooled)  # (B, 1)
        miss_log = self.miss_head(x_pooled)  # (B, 1)
        pc_log10 = self.pc_head(x_pooled)  # (B, 1) — log10(Pc)
        return risk_logit, miss_log, pc_log10, temporal_weights

    def _attention_pool(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Pool (B, S, D) hidden states to (B, D) with learned softmax weights.

        Padding positions get zero weight. Rows with no real position pool to
        an all-zero vector, with finite values and gradients.
        """
        scores = self.pool_attention(x).squeeze(-1)  # (B, S)
        scores = scores.masked_fill(~mask, float("-inf"))
        has_real = mask.any(dim=-1, keepdim=True)  # (B, 1)
        # Softmax over an all -inf row is NaN, and zeroing it afterwards still
        # leaves NaN gradients. So give such rows finite scores and zero their
        # weights after the softmax instead.
        scores = scores.masked_fill(~has_real, 0.0)
        weights = F.softmax(scores, dim=-1).masked_fill(~has_real, 0.0)  # (B, S)
        return (x * weights.unsqueeze(-1)).sum(dim=1)

    def count_parameters(self) -> int:
        """Return the number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class SigmoidFocalLoss(nn.Module):
    """
    Focal Loss for binary classification (Lin et al., 2017).

    Down-weights well-classified examples so the model focuses on hard cases:

        FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    p_t is the predicted probability of the true class. alpha_t is ``alpha``
    for positive targets and ``1 - alpha`` for negative ones. With gamma=0
    this reduces to alpha-weighted BCE. With gamma=2, examples with p_t >= 0.9
    are down-weighted by a factor of at least 100.

    Args:
        alpha: Weight of the positive class (negatives get ``1 - alpha``).
        gamma: Focusing exponent.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, alpha: float = 0.75, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            # Previously any unknown value (including "sum") silently meant "mean".
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            logits: Raw (pre-sigmoid) scores, any shape.
            targets: 0/1 labels with the same shape as ``logits``. They are
                cast to ``logits.dtype``, so int/bool labels are accepted.

        Returns:
            The per-element loss for ``reduction="none"``, otherwise a scalar.
        """
        # BCE-with-logits requires floating targets; arithmetic below needs them too.
        targets = targets.to(logits.dtype)
        p = torch.sigmoid(logits)
        # p_t = probability of the true class
        p_t = targets * p + (1 - targets) * (1 - p)
        # alpha_t = alpha for positive class, (1-alpha) for negative
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        # focal modulator: (1 - p_t)^gamma
        focal_weight = (1 - p_t) ** self.gamma
        # BCE per-element, computed from logits for numerical stability
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        loss = alpha_t * focal_weight * bce
        if self.reduction == "none":
            return loss
        if self.reduction == "sum":
            return loss.sum()
        return loss.mean()
class PhysicsInformedLoss(nn.Module):
    """
    Combined task loss + physics regularization.

    Total loss = risk_weight * L_risk + miss_weight * MSE(miss_distance)
                 + pc_weight * MSE(log10_Pc)
                 + physics_weight * mean(ReLU(MOID - predicted_miss))

    L_risk is BCE-with-logits weighted by ``pos_weight`` by default, or
    SigmoidFocalLoss when ``use_focal=True`` (``pos_weight`` is then unused).

    The physics term: MOID (Minimum Orbital Intersection Distance) is the
    geometric minimum distance between two orbits. The actual miss distance
    at closest approach CANNOT be less than MOID (without a maneuver).
    If the model predicts miss < MOID, we penalize it. Both quantities are
    compared in log1p space, which preserves their ordering.

    The Pc term: direct regression on log10(collision probability). The Kelvins
    `risk` column is log10(Pc), giving us 162K labeled examples. This lets
    the model output collision probabilities, not just binary risk.

    MOID is supplied by the caller through ``moid_log``. How it is obtained
    (e.g. approximated from orbital elements) is up to the caller. When it is
    None the physics term is 0.

    ``domain_weight`` (optional, shape (B,)) scales the per-sample risk, miss
    and Pc losses before averaging, for both the BCE and the focal risk loss.
    The physics term is never weighted.
    """

    def __init__(
        self,
        risk_weight: float = 1.0,
        miss_weight: float = 0.1,
        pc_weight: float = 0.3,
        physics_weight: float = 0.2,
        pos_weight: float = 50.0,
        use_focal: bool = False,
        focal_alpha: float = 0.75,
        focal_gamma: float = 2.0,
    ):
        super().__init__()
        self.risk_weight = risk_weight
        self.miss_weight = miss_weight
        self.pc_weight = pc_weight
        self.physics_weight = physics_weight
        if use_focal:
            self.risk_loss = SigmoidFocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        else:
            self.risk_loss = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(pos_weight)
            )
        self.miss_loss = nn.MSELoss()

    @staticmethod
    def _weighted_mean(values: torch.Tensor, weight: torch.Tensor = None) -> torch.Tensor:
        """Mean of ``values``, first scaled element-wise by ``weight`` when given."""
        if weight is None:
            return values.mean()
        return (values * weight).mean()

    def _risk_per_sample(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Unreduced risk loss (same shape as ``logits``) for the configured criterion."""
        targets = targets.to(logits.dtype)
        if isinstance(self.risk_loss, SigmoidFocalLoss):
            # Same alpha/gamma as the configured criterion, but unreduced, so
            # domain weights can be applied per sample (previously ignored here).
            per_sample_focal = SigmoidFocalLoss(
                alpha=self.risk_loss.alpha,
                gamma=self.risk_loss.gamma,
                reduction="none",
            )
            return per_sample_focal(logits, targets)
        return F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=self.risk_loss.pos_weight.to(logits.device),
            reduction="none",
        )

    def forward(
        self,
        risk_logit: torch.Tensor,  # (B, 1)
        miss_pred_log: torch.Tensor,  # (B, 1)
        risk_target: torch.Tensor,  # (B,)
        miss_target_log: torch.Tensor,  # (B,)
        pc_pred_log10: torch.Tensor = None,  # (B, 1) predicted log10(Pc)
        pc_target_log10: torch.Tensor = None,  # (B,) target log10(Pc)
        moid_log: torch.Tensor = None,  # (B,) optional, log1p(MOID_km)
        domain_weight: torch.Tensor = None,  # (B,) per-sample weight
    ) -> tuple[torch.Tensor, dict]:
        """Compute the combined loss.

        The Pc term is included only when both ``pc_pred_log10`` and
        ``pc_target_log10`` are given. The physics term is included only when
        ``moid_log`` is given.

        Returns:
            total: Scalar loss tensor to backpropagate.
            metrics: Python floats for "loss", "risk_loss", "miss_loss",
                "pc_loss" and "physics_loss" (unweighted component values).
        """
        device = risk_logit.device

        # Risk classification loss (BCE with pos_weight, or focal), domain-weighted
        L_risk = self._weighted_mean(
            self._risk_per_sample(risk_logit.squeeze(-1), risk_target), domain_weight
        )

        # Miss distance regression loss (squared error in log1p space)
        L_miss = self._weighted_mean(
            (miss_pred_log.squeeze(-1) - miss_target_log) ** 2, domain_weight
        )

        # Collision probability regression loss (squared error in log10 space)
        L_pc = torch.tensor(0.0, device=device)
        if pc_pred_log10 is not None and pc_target_log10 is not None:
            L_pc = self._weighted_mean(
                (pc_pred_log10.squeeze(-1) - pc_target_log10) ** 2, domain_weight
            )

        # Physics constraint: predicted miss >= MOID.
        # Violation = how far below MOID the prediction is (0 when satisfied).
        L_physics = torch.tensor(0.0, device=device)
        if moid_log is not None:
            L_physics = F.relu(moid_log - miss_pred_log.squeeze(-1)).mean()

        total = (self.risk_weight * L_risk
                 + self.miss_weight * L_miss
                 + self.pc_weight * L_pc
                 + self.physics_weight * L_physics)

        metrics = {
            "loss": total.item(),
            "risk_loss": L_risk.item(),
            "miss_loss": L_miss.item(),
            "pc_loss": L_pc.item(),
            "physics_loss": L_physics.item(),
        }
        return total, metrics