File size: 8,833 Bytes
0dfdc08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

try:
    from .TimesNet import DataEmbedding
except Exception:
    from TimesNet import DataEmbedding


class _BlockConfig:
    def __init__(self, seq_len: int, pred_len: int, d_model: int, d_ff: int, num_kernels: int, top_k: int = 2, num_stations: int = 0):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_kernels = num_kernels
        self.top_k = top_k
        self.num_stations = num_stations


class Inception_Block_V1(nn.Module):
    """Inception-style multi-scale 2D convolution.

    Runs ``num_kernels`` parallel ``nn.Conv2d`` layers with odd kernel sizes
    1, 3, 5, ... (padding ``i`` for kernel ``2*i+1`` preserves spatial size)
    and averages their outputs.

    Args:
        in_channels: input channel count.
        out_channels: output channel count (each parallel conv emits this).
        num_kernels: number of parallel convolutions / scales.
        init_weight: apply Kaiming init to conv weights when True.
    """
    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        # kernel_size = 2*i + 1 with padding = i keeps H and W unchanged.
        self.kernels = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
             for i in range(self.num_kernels)]
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init suits the ReLU/GELU nonlinearities used downstream.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Average the outputs of all parallel convolutions.

        Args:
            x: (B, C_in, H, W) tensor.
        Returns:
            (B, C_out, H, W) tensor — mean over the num_kernels branches.
        """
        res = torch.stack([kernel(x) for kernel in self.kernels], dim=-1).mean(-1)
        return res


def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of a batch of series via the rFFT.

    Args:
        x: (B, T, C) tensor.
        k: number of dominant frequencies to keep.

    Returns:
        period: numpy int array of shape (k,), ``T // freq_index`` for each of
            the top-k frequencies (shared across the whole batch).
        weight: (B, k) tensor of per-sample mean amplitudes at those
            frequencies, used later as aggregation weights.
    """
    # [B, T, C] -> [B, T//2+1, C] complex spectrum over the time axis.
    xf = torch.fft.rfft(x, dim=1)
    # Compute the amplitude once and reuse it for both the frequency ranking
    # and the returned per-sample weights (the original recomputed abs(xf)).
    amplitude = xf.abs()
    # Rank frequencies by amplitude averaged over batch and channels.
    frequency_list = amplitude.mean(0).mean(-1)
    frequency_list[0] = 0  # mask the DC component so it is never selected
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # Integer division; safe because index 0 (DC) was masked above.
    period = x.shape[1] // top_list
    return period, amplitude.mean(-1)[:, top_list]


class TimesBlockStationCond(nn.Module):
    """TimesBlock whose features are conditioned on a per-sample station ID.

    Conditioning is an additive learned ``nn.Embedding`` of size ``d_model``
    broadcast over the 2D (period-folded) feature map. NOTE(review): the
    original docstring said "one-hot encoded as 1 channel", which does not
    match the implementation below.
    """
    def __init__(self, configs):
        # configs: attribute bag exposing seq_len, pred_len, top_k, d_model,
        # d_ff, num_kernels and optionally num_stations (see _BlockConfig).
        super(TimesBlockStationCond, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        self.num_stations = getattr(configs, 'num_stations', 0)
        
        # Station ID embedding: maps station ID to d_model dimension.
        # This provides richer conditioning information than a single scalar.
        # Only created when the config declares at least one station.
        if self.num_stations > 0:
            self.station_embedding = nn.Embedding(self.num_stations, configs.d_model)
            # Initialize with small random values
            nn.init.normal_(self.station_embedding.weight, mean=0.0, std=0.02)
        
        # Inception stack: d_model -> d_ff -> d_model with a GELU in between.
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x, station_ids: torch.Tensor = None):
        """
        Args:
            x: (B, T, N) input features. NOTE(review): the station-embedding
               addition below requires N == d_model — confirm at call sites.
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1)

        Returns:
            (B, T, N) tensor: multi-period 2D-conv features aggregated by
            softmaxed FFT amplitudes, plus a residual connection to x.
        """
        B, T, N = x.size()
        # Dominant periods (shared across batch) and per-sample amplitudes.
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # Zero-pad the time axis so its length divides evenly by `period`.
            # NOTE(review): padding length is derived from seq_len + pred_len,
            # which assumes T == seq_len + pred_len — confirm upstream.
            if (self.seq_len + self.pred_len) % period != 0:
                length = (((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x
            
            # Fold 1D time into 2D: (B, N, length // period, period).
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
            
            # Inject station ID conditioning via embedding addition.
            # Skipped when no IDs are passed or no stations were configured.
            if station_ids is not None and self.num_stations > 0:
                # Get station embeddings: (B, d_model)
                station_ids_flat = station_ids.view(B)
                station_emb = self.station_embedding(station_ids_flat)  # (B, d_model)
                
                # out shape: (B, N, H, W); this view only works when N == d_model.
                # Expand the embedding over both spatial dimensions.
                H = out.size(2)
                W = out.size(3)
                station_emb_spatial = station_emb.view(B, N, 1, 1).expand(-1, -1, H, W)
                
                # Add station embedding to features (element-wise addition)
                # so the conv stack sees station-shifted activations.
                out = out + station_emb_spatial
            
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            
            # Flatten back to (B, length, N) and drop any padded tail.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])
        
        res = torch.stack(res, dim=-1)  # (B, T, N, k)
        # Adaptive aggregation: softmax the k periods' amplitudes per sample.
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        
        # residual connection
        res = res + x
        return res


class TimesNetPointCloud(nn.Module):
    """TimesNet reconstruction with exposed encode/project methods for point-cloud mixing."""
    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.pred_len = getattr(configs, 'pred_len', 0)
        self.top_k = configs.top_k
        self.d_model = configs.d_model
        self.d_ff = configs.d_ff
        self.num_kernels = configs.num_kernels
        self.e_layers = configs.e_layers
        self.dropout = configs.dropout
        self.c_out = configs.c_out

        self.num_stations = getattr(configs, 'num_stations', 0)

        self.enc_embedding = DataEmbedding(configs.enc_in, self.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        self.model = nn.ModuleList([
            TimesBlockStationCond(_BlockConfig(self.seq_len, 0, self.d_model, self.d_ff, 
                                              self.num_kernels, self.top_k, self.num_stations))
            for _ in range(self.e_layers)
        ])
        self.layer = self.e_layers
        self.layer_norm = nn.LayerNorm(self.d_model)
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def encode_features_for_reconstruction(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """
        Encode input with optional station ID conditioning.
        
        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1), optional
        """
        means = x_enc.mean(1, keepdim=True).detach()
        x_norm = x_enc - means
        stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_norm = x_norm / stdev
        enc_out = self.enc_embedding(x_norm, None)
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out, station_ids))
        return enc_out, means, stdev

    def project_features_for_reconstruction(self, enc_out: torch.Tensor, means: torch.Tensor, stdev: torch.Tensor):
        dec_out = self.projection(enc_out)
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        return dec_out

    def anomaly_detection(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """Full reconstruction pass with optional station ID conditioning."""
        enc_out, means, stdev = self.encode_features_for_reconstruction(x_enc, station_ids)
        return self.project_features_for_reconstruction(enc_out, means, stdev)

    def forward(self, x_enc, station_ids=None, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """
        Forward pass compatible with anomaly_detection task.
        
        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs, optional
        """
        return self.anomaly_detection(x_enc, station_ids)