# panacea-api/src/model/pretrain.py
# Uploaded by DTanzillo via huggingface_hub (commit a4b5ecb, verified)
# Generated by Claude Code -- 2026-02-10
"""Self-supervised pre-training for the PI-TFT encoder.
Masked Feature Reconstruction: mask 60% of CDM temporal features at random
per timestep, train the Transformer encoder to reconstruct them. This forces
the model to learn feature correlations, temporal dynamics, and
static-temporal interactions from ALL CDM data (no labels needed).
"""
import torch
import torch.nn as nn
from src.model.deep import PhysicsInformedTFT
class CDMMaskingStrategy(nn.Module):
    """Randomly mask temporal features per timestep for reconstruction pre-training.

    For each real timestep (respecting the padding mask), a random fraction of
    the temporal features is replaced by a learnable [MASK] token so the encoder
    must reconstruct them from context.
    """

    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # One learnable [MASK] value per temporal feature.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(
        self,
        temporal: torch.Tensor,      # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply random per-feature masking.

        Returns:
            masked_temporal: (B, S, F_t) with masked positions set to mask_token
            feature_mask: (B, S, F_t) bool — True where features were masked
        """
        batch, seq_len, n_feat = temporal.shape
        # Draw the candidate mask (True = reconstruct), then restrict it to
        # real timesteps so padding is never masked.
        candidates = torch.rand(batch, seq_len, n_feat, device=temporal.device) < self.mask_ratio
        feature_mask = candidates & padding_mask.unsqueeze(-1)
        # Broadcast the (F_t,) mask token over batch and sequence dims and
        # select it wherever the mask is set; elsewhere keep the original value.
        masked_temporal = torch.where(feature_mask, self.mask_token, temporal)
        return masked_temporal, feature_mask
class MaskedReconstructionHead(nn.Module):
    """Lightweight 2-layer MLP decoder for feature reconstruction.

    Deliberately kept small so representational capacity lives in the encoder,
    not here.
    """

    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        # Layer order matters for checkpoint compatibility: norm -> hidden
        # projection -> GELU -> dropout -> output projection.
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Reconstruct temporal features from encoder hidden states.

        Args:
            hidden: (B, S, D) per-timestep encoder output

        Returns:
            (B, S, F_t) reconstructed temporal features
        """
        reconstructed = self.decoder(hidden)
        return reconstructed
class PretrainingWrapper(nn.Module):
    """Wrap the PI-TFT encoder with a masking strategy and reconstruction head.

    Forward pass: mask temporal features with the learnable [MASK] token,
    run the encoder's encode_sequence(), then decode the hidden states back
    to temporal-feature space.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        # Submodule registration order (encoder, masking, head) is kept stable
        # so checkpoints and RNG-dependent initialization stay reproducible.
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )

    def forward(
        self,
        temporal: torch.Tensor,     # (B, S, F_t)
        static: torch.Tensor,       # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1)
        mask: torch.Tensor,         # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
            feature_mask: (B, S, F_t) bool — True where features were masked
            original: (B, S, F_t) untouched copy of the inputs, for the loss
        """
        # Keep a detached-from-masking copy of the targets before any edits.
        clean_targets = temporal.clone()
        corrupted, feature_mask = self.masking(temporal, mask)
        # Encode the corrupted sequence; the second return value is unused here.
        hidden, _ = self.encoder.encode_sequence(corrupted, static, time_to_tca, mask)
        return self.reconstruction_head(hidden), feature_mask, clean_targets
class PretrainingLoss(nn.Module):
    """MSE loss computed only on masked positions."""

    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,       # (B, S, F_t)
        feature_mask: torch.Tensor,   # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        # Squared error restricted to the masked entries only.
        errors = torch.masked_select((reconstructed - original) ** 2, feature_mask)
        if errors.numel() > 0:
            loss = errors.mean()
        else:
            # No masked positions (e.g. all-padding batch): return a zero that
            # still participates in autograd so callers can backward() safely.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)
        return loss, {"reconstruction_loss": loss.item()}