# Generated by Claude Code -- 2026-02-10 """Self-supervised pre-training for the PI-TFT encoder. Masked Feature Reconstruction: mask 60% of CDM temporal features at random per timestep, train the Transformer encoder to reconstruct them. This forces the model to learn feature correlations, temporal dynamics, and static-temporal interactions from ALL CDM data (no labels needed). """ import torch import torch.nn as nn from src.model.deep import PhysicsInformedTFT class CDMMaskingStrategy(nn.Module): """Randomly mask temporal features per timestep for reconstruction pre-training. For each real timestep (respecting padding mask), replaces a fraction of the temporal features with a learnable [MASK] token. """ def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6): super().__init__() self.n_temporal_features = n_temporal_features self.mask_ratio = mask_ratio # Learnable [MASK] token — one value per temporal feature self.mask_token = nn.Parameter(torch.zeros(n_temporal_features)) nn.init.normal_(self.mask_token, std=0.02) def forward( self, temporal: torch.Tensor, # (B, S, F_t) padding_mask: torch.Tensor, # (B, S) True=real, False=padding ) -> tuple[torch.Tensor, torch.Tensor]: """Apply random feature masking. Returns: masked_temporal: (B, S, F_t) with masked positions replaced by mask_token feature_mask: (B, S, F_t) bool — True where features were masked """ B, S, F = temporal.shape # Generate random mask: True = masked (to reconstruct) feature_mask = torch.rand(B, S, F, device=temporal.device) < self.mask_ratio # Only mask real timesteps (not padding) feature_mask = feature_mask & padding_mask.unsqueeze(-1) # Replace masked positions with learnable mask token masked_temporal = temporal.clone() masked_temporal[feature_mask] = self.mask_token.expand(B, S, -1)[feature_mask] return masked_temporal, feature_mask class MaskedReconstructionHead(nn.Module): """Lightweight 2-layer MLP decoder for feature reconstruction. Intentionally small to force the encoder (not the decoder) to learn rich representations. """ def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1): super().__init__() self.decoder = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model, n_temporal_features), ) def forward(self, hidden: torch.Tensor) -> torch.Tensor: """Reconstruct temporal features from encoder hidden states. Args: hidden: (B, S, D) per-timestep encoder output Returns: reconstructed: (B, S, F_t) reconstructed temporal features """ return self.decoder(hidden) class PretrainingWrapper(nn.Module): """Wraps PI-TFT encoder with masking strategy and reconstruction head. Forward pass: generate mask → apply mask token → encode_sequence() → reconstruct → return reconstructed + masks. """ def __init__( self, n_temporal_features: int, n_static_features: int, d_model: int = 128, n_heads: int = 4, n_layers: int = 2, dropout: float = 0.15, mask_ratio: float = 0.6, ): super().__init__() self.encoder = PhysicsInformedTFT( n_temporal_features=n_temporal_features, n_static_features=n_static_features, d_model=d_model, n_heads=n_heads, n_layers=n_layers, dropout=dropout, ) self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio) self.reconstruction_head = MaskedReconstructionHead( d_model, n_temporal_features, dropout ) def forward( self, temporal: torch.Tensor, # (B, S, F_t) static: torch.Tensor, # (B, F_s) time_to_tca: torch.Tensor, # (B, S, 1) mask: torch.Tensor, # (B, S) True=real ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: reconstructed: (B, S, F_t) reconstructed temporal features feature_mask: (B, S, F_t) bool — True where features were masked original: (B, S, F_t) original temporal features (for loss computation) """ original = temporal.clone() # Mask temporal features masked_temporal, feature_mask = self.masking(temporal, mask) # Encode masked sequence hidden, _ = self.encoder.encode_sequence( masked_temporal, static, time_to_tca, mask ) # Reconstruct reconstructed = self.reconstruction_head(hidden) return reconstructed, feature_mask, original class PretrainingLoss(nn.Module): """MSE loss computed only on masked positions.""" def forward( self, reconstructed: torch.Tensor, # (B, S, F_t) original: torch.Tensor, # (B, S, F_t) feature_mask: torch.Tensor, # (B, S, F_t) bool ) -> tuple[torch.Tensor, dict]: # MSE on masked positions only masked_diff = (reconstructed - original) ** 2 masked_diff = masked_diff[feature_mask] if masked_diff.numel() == 0: loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True) else: loss = masked_diff.mean() return loss, {"reconstruction_loss": loss.item()}