Spaces:
Sleeping
Sleeping
# Generated by Claude Code -- 2026-02-10
"""Self-supervised pre-training for the PI-TFT encoder.

Masked Feature Reconstruction: mask 60% of CDM temporal features at random
per timestep, train the Transformer encoder to reconstruct them. This forces
the model to learn feature correlations, temporal dynamics, and
static-temporal interactions from ALL CDM data (no labels needed).
"""
import torch
import torch.nn as nn

from src.model.deep import PhysicsInformedTFT
class CDMMaskingStrategy(nn.Module):
    """Randomly mask temporal features per timestep for reconstruction pre-training.

    For each real timestep (respecting the padding mask), a fraction of the
    temporal features is replaced with a learnable [MASK] token.
    """

    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # Learnable [MASK] token -- one value per temporal feature.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(
        self,
        temporal: torch.Tensor,      # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply random feature masking.

        Returns:
            masked_temporal: (B, S, F_t) with masked positions replaced by mask_token
            feature_mask: (B, S, F_t) bool -- True where features were masked
        """
        batch, seq_len, n_feat = temporal.shape
        # Draw one Bernoulli(mask_ratio) decision per (batch, step, feature),
        # then restrict masking to real (non-padding) timesteps.
        draw = torch.rand(batch, seq_len, n_feat, device=temporal.device)
        feature_mask = (draw < self.mask_ratio) & padding_mask.unsqueeze(-1)
        # Broadcast the (F_t,) mask token into the masked positions; all other
        # positions keep their original values.
        masked_temporal = torch.where(feature_mask, self.mask_token, temporal)
        return masked_temporal, feature_mask
class MaskedReconstructionHead(nn.Module):
    """Lightweight 2-layer MLP decoder for feature reconstruction.

    Kept intentionally small so that the encoder, not the decoder, is forced
    to learn rich representations.
    """

    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        # LayerNorm -> Linear -> GELU -> Dropout -> projection back down to
        # the temporal-feature dimension.
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Reconstruct temporal features from encoder hidden states.

        Args:
            hidden: (B, S, D) per-timestep encoder output

        Returns:
            (B, S, F_t) reconstructed temporal features
        """
        return self.decoder(hidden)
class PretrainingWrapper(nn.Module):
    """Wraps the PI-TFT encoder with a masking strategy and reconstruction head.

    Forward pass: generate mask -> apply mask token -> encode_sequence() ->
    reconstruct -> return reconstructed + masks.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        # The encoder is the model being pre-trained; masking strategy and
        # reconstruction head are pre-training-only scaffolding.
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )

    def forward(
        self,
        temporal: torch.Tensor,     # (B, S, F_t)
        static: torch.Tensor,       # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1)
        mask: torch.Tensor,         # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
            feature_mask: (B, S, F_t) bool -- True where features were masked
            original: (B, S, F_t) original temporal features (for loss computation)
        """
        # Snapshot the untouched features before masking, for the loss target.
        original = temporal.clone()
        masked_temporal, feature_mask = self.masking(temporal, mask)
        # encode_sequence yields per-timestep hidden states plus a secondary
        # output that pre-training does not need.
        hidden, _ = self.encoder.encode_sequence(
            masked_temporal, static, time_to_tca, mask
        )
        reconstructed = self.reconstruction_head(hidden)
        return reconstructed, feature_mask, original
class PretrainingLoss(nn.Module):
    """MSE loss computed only on masked positions."""

    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,       # (B, S, F_t)
        feature_mask: torch.Tensor,   # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        """Return (scalar loss, metrics dict); only masked entries contribute."""
        # Squared error restricted to the masked entries.
        sq_err = (reconstructed - original).pow(2)[feature_mask]
        if sq_err.numel() > 0:
            loss = sq_err.mean()
        else:
            # Nothing was masked: fall back to a zero loss that still carries
            # requires_grad so a caller's .backward() never errors.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)
        return loss, {"reconstruction_loss": loss.item()}