# Generated by Claude Code -- 2026-02-10
"""Self-supervised pre-training for the PI-TFT encoder.
Masked Feature Reconstruction: mask 60% of CDM temporal features at random
per timestep, train the Transformer encoder to reconstruct them. This forces
the model to learn feature correlations, temporal dynamics, and
static-temporal interactions from ALL CDM data (no labels needed).
"""
import torch
import torch.nn as nn
from src.model.deep import PhysicsInformedTFT
class CDMMaskingStrategy(nn.Module):
    """Per-timestep random feature masking for reconstruction pre-training.

    A learnable [MASK] token (one scalar per temporal feature) is written
    over a random subset of feature values at every real timestep; padded
    timesteps are never masked.
    """
    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # Learnable [MASK] token -- one value per temporal feature,
        # initialised with a small normal perturbation around zero.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)
    def forward(
        self,
        temporal: torch.Tensor,  # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply random feature masking.
        Returns:
            masked_temporal: (B, S, F_t) with masked positions replaced by mask_token
            feature_mask: (B, S, F_t) bool -- True where features were masked
        """
        batch, seq_len, n_feat = temporal.shape
        # Draw candidate positions (True = to be reconstructed), then drop
        # any that fall on padded timesteps.
        candidates = (
            torch.rand(batch, seq_len, n_feat, device=temporal.device)
            < self.mask_ratio
        )
        feature_mask = candidates & padding_mask.unsqueeze(-1)
        # Broadcast the token over (B, S) and select it wherever masked;
        # keep the original value elsewhere.
        token = self.mask_token.expand(batch, seq_len, n_feat)
        masked_temporal = torch.where(feature_mask, token, temporal)
        return masked_temporal, feature_mask
class MaskedReconstructionHead(nn.Module):
    """Small two-layer MLP that maps encoder states back to temporal features.

    Deliberately shallow so that the representational work is pushed into
    the encoder rather than this decoder.
    """
    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        # norm -> linear -> GELU -> dropout -> projection to F_t features.
        stages = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*stages)
    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Reconstruct temporal features from encoder hidden states.
        Args:
            hidden: (B, S, D) per-timestep encoder output
        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
        """
        reconstructed = self.decoder(hidden)
        return reconstructed
class PretrainingWrapper(nn.Module):
    """Couples the PI-TFT encoder with masking and a reconstruction head.

    Forward pass: generate mask -> apply mask token -> encode_sequence() ->
    reconstruct -> return reconstructed features plus the masks needed for
    the loss.
    """
    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        # The encoder being pre-trained; its weights are what we keep.
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )
    def forward(
        self,
        temporal: torch.Tensor,  # (B, S, F_t)
        static: torch.Tensor,  # (B, F_s)
        time_to_tca: torch.Tensor,  # (B, S, 1)
        mask: torch.Tensor,  # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
            feature_mask: (B, S, F_t) bool -- True where features were masked
            original: (B, S, F_t) original temporal features (for loss computation)
        """
        # Snapshot the unmasked input first: it is the reconstruction target.
        original = temporal.clone()
        masked_temporal, feature_mask = self.masking(temporal, mask)
        # Encode the corrupted sequence; only per-timestep states are needed.
        hidden, _ = self.encoder.encode_sequence(
            masked_temporal, static, time_to_tca, mask
        )
        return self.reconstruction_head(hidden), feature_mask, original
class PretrainingLoss(nn.Module):
    """MSE reconstruction loss restricted to masked feature positions."""
    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,  # (B, S, F_t)
        feature_mask: torch.Tensor,  # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        # Squared errors, keeping only positions that were masked.
        errors = (reconstructed - original).pow(2)[feature_mask]
        if errors.numel():
            loss = errors.mean()
        else:
            # Nothing was masked in this batch: return a zero that still
            # participates in autograd so callers can .backward() safely.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)
        return loss, {"reconstruction_loss": loss.item()}
|