Spaces:
Runtime error
Runtime error
| """ | |
| TFT (Temporal Fusion Transformer) 模型实现 | |
| 简化版本,支持静态特征、已知未来特征和观测时序特征 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| class TFTEncoder(nn.Module): | |
| """TFT编码器""" | |
| def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1): | |
| super(TFTEncoder, self).__init__() | |
| encoder_layer = nn.TransformerEncoderLayer( | |
| d_model=d_model, | |
| nhead=nhead, | |
| dim_feedforward=dim_feedforward, | |
| dropout=dropout, | |
| batch_first=True | |
| ) | |
| self.encoder = nn.TransformerEncoder(encoder_layer, num_layers) | |
| def forward(self, x, mask=None): | |
| """ | |
| 前向传播 | |
| 参数: | |
| x: 输入 [batch_size, seq_len, d_model] | |
| mask: 注意力mask [batch_size, seq_len] | |
| 返回: | |
| output: 编码输出 [batch_size, seq_len, d_model] | |
| """ | |
| return self.encoder(x, src_key_padding_mask=mask) | |
class TFTDecoder(nn.Module):
    """Transformer decoder stack used as the TFT prediction decoder.

    Thin wrapper around ``nn.TransformerDecoder`` configured with
    ``batch_first=True``; the ``memory_mask`` argument of :meth:`forward`
    is applied as a key-padding mask over the encoder memory.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()
        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Decode against encoder memory.

        Args:
            tgt: target sequence ``[batch_size, tgt_len, d_model]``.
            memory: encoder output ``[batch_size, seq_len, d_model]``.
            tgt_mask: optional attention mask for the target sequence.
            memory_mask: optional key-padding mask ``[batch_size, seq_len]``
                over the memory (True = ignore), forwarded as
                ``memory_key_padding_mask``.

        Returns:
            Decoded tensor ``[batch_size, tgt_len, d_model]``.
        """
        return self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_mask)
class TemporalFusionTransformer(nn.Module):
    """
    Temporal Fusion Transformer (simplified).

    Supports three feature groups: static features (broadcast over time and
    fused with the observed sequence), observed time-series features (encoded
    by a Transformer encoder), and known-future features (embedded and decoded
    against the encoded history to produce the forecast).
    """

    def __init__(self, num_observed_features, num_static_features, num_known_future_features,
                 num_output_features=None, hidden_size=128, num_heads=4, num_encoder_layers=3,
                 num_decoder_layers=3, dim_feedforward=512, dropout=0.1):
        """
        Initialize the TFT.

        Args:
            num_observed_features: observed time-series feature dimension
                (e.g. the output of an upstream sequence model).
            num_static_features: static feature dimension.
            num_known_future_features: known-future feature dimension.
            num_output_features: prediction feature dimension; defaults to
                ``num_observed_features`` when None.
            hidden_size: hidden (model) dimension.
            num_heads: number of attention heads.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            dim_feedforward: feed-forward dimension inside each layer.
            dropout: dropout rate.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_observed_features = num_observed_features
        self.num_static_features = num_static_features
        self.num_known_future_features = num_known_future_features
        # Default: predict the same feature dimension as the observed input.
        self.num_output_features = num_output_features or num_observed_features
        # Per-group feature embeddings.
        self.observed_embedding = nn.Linear(num_observed_features, hidden_size)
        self.static_embedding = nn.Linear(num_static_features, hidden_size)
        self.known_future_embedding = nn.Linear(num_known_future_features, hidden_size)
        # Positional encoding shared by encoder and decoder inputs.
        self.pos_encoder = PositionalEncoding(hidden_size, dropout)
        # Encoder over the fused (observed + static) history.
        self.encoder = TFTEncoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        # Decoder driven by known-future features.
        self.decoder = TFTDecoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        # Output projection back to feature space.
        self.output_layer = nn.Linear(hidden_size, self.num_output_features)
        # Fuses the broadcast static embedding with the observed embedding.
        self.static_fusion = nn.Linear(hidden_size * 2, hidden_size)

    @staticmethod
    def _normalize_mask(mask):
        """Normalize a mask to a 2-D key-padding mask without mutating the input.

        Args:
            mask: one of
                - None,
                - ``[batch, seq, num_features]`` with 1 = valid / 0 = missing
                  (a time step is masked only when ALL its features are missing),
                - ``[batch, seq]`` boolean with True = padding.

        Returns:
            ``[batch, seq]`` boolean mask (True = masked), or None when the
            whole batch would be masked (masking is skipped to avoid NaNs in
            attention). Samples whose every time step is masked keep their
            first step unmasked for the same reason.
        """
        if mask is None:
            return None
        if mask.dim() == 3:
            # All features missing at a step (sum == 0) -> mask that step.
            mask_2d = (mask.sum(dim=-1) == 0)
        else:
            mask_2d = mask.bool()
        if mask_2d.all():
            # Entire batch masked: disable masking instead of feeding an
            # all-True key-padding mask into the Transformer.
            return None
        fully_masked_rows = mask_2d.all(dim=1)
        if fully_masked_rows.any():
            # Clone before editing: when the caller passed an already-bool
            # 2-D mask, `mask_2d` aliases it, and in-place edits would
            # silently corrupt the caller's tensor (bug in the original).
            mask_2d = mask_2d.clone()
            mask_2d[fully_masked_rows, 0] = False
        return mask_2d

    def forward(self, observed_features, static_features, known_future_features, mask=None, debug=False):
        """
        Forward pass.

        Args:
            observed_features: observed series ``[batch_size, seq_len, num_observed_features]``.
            static_features: static features ``[batch_size, num_static_features]``.
            known_future_features: known-future features ``[batch_size, pred_len, num_known_future_features]``.
            mask: attention mask, ``[batch_size, seq_len, num_features]``
                (1 = valid, 0 = missing) or ``[batch_size, seq_len]``
                (True = padding); see :meth:`_normalize_mask`.
            debug: when True, print NaN diagnostics at each stage.

        Returns:
            Predictions ``[batch_size, pred_len, num_output_features]``.
        """
        batch_size, seq_len, _ = observed_features.size()
        if debug:
            print(f" [TFT] 输入检查:")
            print(f" observed_features: shape={observed_features.shape}, has_nan={torch.isnan(observed_features).any().item()}")
            print(f" static_features: shape={static_features.shape}, has_nan={torch.isnan(static_features).any().item()}")
            print(f" known_future_features: shape={known_future_features.shape}, has_nan={torch.isnan(known_future_features).any().item()}")
        # Embed observed features and add positional encoding.
        observed_emb = self.observed_embedding(observed_features)  # [batch, seq, hidden]
        if debug and torch.isnan(observed_emb).any():
            print(f" ⚠️ observed_emb包含NaN(在embedding后)")
        observed_emb = self.pos_encoder(observed_emb)
        if debug and torch.isnan(observed_emb).any():
            print(f" ⚠️ observed_emb包含NaN(在pos_encoder后)")
        # Embed static features and broadcast across time.
        static_emb = self.static_embedding(static_features)  # [batch, hidden]
        if debug and torch.isnan(static_emb).any():
            print(f" ⚠️ static_emb包含NaN(在embedding后)")
        static_emb = static_emb.unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq, hidden]
        # Fuse static and observed embeddings.
        encoder_input = torch.cat([observed_emb, static_emb], dim=-1)  # [batch, seq, hidden*2]
        if debug and torch.isnan(encoder_input).any():
            print(f" ⚠️ encoder_input包含NaN(在concat后)")
        encoder_input = self.static_fusion(encoder_input)  # [batch, seq, hidden]
        if debug and torch.isnan(encoder_input).any():
            print(f" ⚠️ encoder_input包含NaN(在static_fusion后)")
        # Normalize the mask (handles 3-D feature masks, all-masked batches,
        # and fully-masked samples) without mutating the caller's tensor.
        mask_2d = self._normalize_mask(mask)
        # Encode the history.
        encoder_output = self.encoder(encoder_input, mask=mask_2d)  # [batch, seq, hidden]
        if debug and torch.isnan(encoder_output).any():
            print(f" ⚠️ encoder_output包含NaN(在encoder后)")
        # Embed known-future features and add positional encoding.
        known_future_emb = self.known_future_embedding(known_future_features)  # [batch, pred, hidden]
        if debug and torch.isnan(known_future_emb).any():
            print(f" ⚠️ known_future_emb包含NaN(在embedding后)")
        known_future_emb = self.pos_encoder(known_future_emb)
        if debug and torch.isnan(known_future_emb).any():
            print(f" ⚠️ known_future_emb包含NaN(在pos_encoder后)")
        # Decode against the encoded history (same padding mask on memory).
        decoder_output = self.decoder(
            known_future_emb,
            encoder_output,
            memory_mask=mask_2d,
        )  # [batch, pred, hidden]
        if debug and torch.isnan(decoder_output).any():
            print(f" ⚠️ decoder_output包含NaN(在decoder后)")
        # Project back to feature space.
        predictions = self.output_layer(decoder_output)  # [batch, pred, num_output_features]
        if debug and torch.isnan(predictions).any():
            print(f" ⚠️ predictions包含NaN(在output_layer后)")
        return predictions
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with dropout (Vaswani et al. style).

    Precomputes a ``[1, max_len, d_model]`` table of sin/cos encodings and
    adds the leading slice of it to each input batch.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len).unsqueeze(1)
        # Same arithmetic as the classic formulation: 10000^(-2i/d_model).
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Registered as a buffer so it follows .to(device) but is not trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x`` ([batch, seq, d_model]) and apply dropout."""
        return self.dropout(x + self.pe[:, :x.size(1), :])