File size: 5,835 Bytes
a4b5ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Generated by Claude Code -- 2026-02-10
"""Self-supervised pre-training for the PI-TFT encoder.

Masked Feature Reconstruction: mask 60% of CDM temporal features at random
per timestep, train the Transformer encoder to reconstruct them. This forces
the model to learn feature correlations, temporal dynamics, and
static-temporal interactions from ALL CDM data (no labels needed).
"""

import torch
import torch.nn as nn

from src.model.deep import PhysicsInformedTFT


class CDMMaskingStrategy(nn.Module):
    """Mask a random subset of temporal features for reconstruction pre-training.

    Each real timestep (padding excluded via the padding mask) has a fraction
    of its temporal features swapped for a learnable [MASK] token.
    """

    def __init__(self, n_temporal_features: int, mask_ratio: float = 0.6):
        super().__init__()
        self.n_temporal_features = n_temporal_features
        self.mask_ratio = mask_ratio
        # One learnable [MASK] value per temporal feature.
        self.mask_token = nn.Parameter(torch.zeros(n_temporal_features))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(
        self,
        temporal: torch.Tensor,  # (B, S, F_t)
        padding_mask: torch.Tensor,  # (B, S) True=real, False=padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Corrupt ``temporal`` by masking random features.

        Returns:
            masked_temporal: (B, S, F_t) input with masked positions replaced
                by the learnable mask token
            feature_mask: (B, S, F_t) bool — True where features were masked
        """
        batch, seq_len, n_feat = temporal.shape

        # Draw the per-element mask, then switch it off on padded timesteps
        # so only real observations are ever corrupted.
        drop = torch.rand(batch, seq_len, n_feat, device=temporal.device) < self.mask_ratio
        drop &= padding_mask.unsqueeze(-1)

        # Broadcast the learnable token over batch and sequence dims and
        # splice it into the input at the dropped positions.
        token_grid = self.mask_token.expand(batch, seq_len, -1)
        corrupted = torch.where(drop, token_grid, temporal)

        return corrupted, drop


class MaskedReconstructionHead(nn.Module):
    """Small two-layer MLP that maps encoder states back to temporal features.

    Kept deliberately tiny so representational capacity has to live in the
    encoder rather than in this decoder.
    """

    def __init__(self, d_model: int, n_temporal_features: int, dropout: float = 0.1):
        super().__init__()
        # Norm -> hidden projection -> nonlinearity -> dropout -> output
        # projection onto the temporal feature dimension.
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_temporal_features),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Map per-timestep hidden states to reconstructed temporal features.

        Args:
            hidden: (B, S, D) per-timestep encoder output

        Returns:
            (B, S, F_t) reconstructed temporal features
        """
        reconstructed = self.decoder(hidden)
        return reconstructed


class PretrainingWrapper(nn.Module):
    """Couples the PI-TFT encoder with masking and a reconstruction head.

    Pipeline per forward pass: draw a feature mask -> substitute [MASK]
    tokens -> encode_sequence() -> decode hidden states back into the
    temporal feature space.
    """

    def __init__(
        self,
        n_temporal_features: int,
        n_static_features: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.15,
        mask_ratio: float = 0.6,
    ):
        super().__init__()
        self.encoder = PhysicsInformedTFT(
            n_temporal_features=n_temporal_features,
            n_static_features=n_static_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.masking = CDMMaskingStrategy(n_temporal_features, mask_ratio)
        self.reconstruction_head = MaskedReconstructionHead(
            d_model, n_temporal_features, dropout
        )

    def forward(
        self,
        temporal: torch.Tensor,    # (B, S, F_t)
        static: torch.Tensor,      # (B, F_s)
        time_to_tca: torch.Tensor, # (B, S, 1)
        mask: torch.Tensor,        # (B, S) True=real
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one masked-reconstruction pass.

        Returns:
            reconstructed: (B, S, F_t) reconstructed temporal features
            feature_mask: (B, S, F_t) bool — True where features were masked
            original: (B, S, F_t) untouched temporal features (loss target)
        """
        # Keep a pristine copy of the input as the reconstruction target.
        target = temporal.clone()

        # Corrupt, encode, decode.
        corrupted, feature_mask = self.masking(temporal, mask)
        per_step_hidden, _ = self.encoder.encode_sequence(
            corrupted, static, time_to_tca, mask
        )
        predicted = self.reconstruction_head(per_step_hidden)

        return predicted, feature_mask, target


class PretrainingLoss(nn.Module):
    """Mean squared error restricted to the masked feature positions."""

    def forward(
        self,
        reconstructed: torch.Tensor,  # (B, S, F_t)
        original: torch.Tensor,       # (B, S, F_t)
        feature_mask: torch.Tensor,   # (B, S, F_t) bool
    ) -> tuple[torch.Tensor, dict]:
        """Return (scalar loss, metrics dict) for one batch."""
        # Squared errors only at positions the masking strategy corrupted.
        errors = torch.masked_select((reconstructed - original) ** 2, feature_mask)

        if errors.numel() > 0:
            loss = errors.mean()
        else:
            # Nothing was masked (e.g. an all-padding batch): contribute a
            # zero loss that still participates in autograd.
            loss = torch.tensor(0.0, device=reconstructed.device, requires_grad=True)

        return loss, {"reconstruction_loss": loss.item()}