# Uploaded to the Hugging Face Hub by oscarzhang via huggingface_hub (commit 6b11dc4, verified)
"""
TFT (Temporal Fusion Transformer) 模型实现
简化版本,支持静态特征、已知未来特征和观测时序特征
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TFTEncoder(nn.Module):
    """Stack of batch-first Transformer encoder layers used by the TFT."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x, mask=None):
        """
        Encode a sequence.

        Args:
            x: input tensor [batch_size, seq_len, d_model].
            mask: optional key-padding mask [batch_size, seq_len]
                (True marks padded positions).

        Returns:
            Encoded tensor [batch_size, seq_len, d_model].
        """
        return self.encoder(x, src_key_padding_mask=mask)
class TFTDecoder(nn.Module):
    """Stack of batch-first Transformer decoder layers used by the TFT."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()
        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """
        Decode a target sequence against encoder memory.

        Args:
            tgt: target sequence [batch_size, tgt_len, d_model].
            memory: encoder output [batch_size, seq_len, d_model].
            tgt_mask: optional attention mask for the target sequence.
            memory_mask: optional key-padding mask for the memory
                [batch_size, seq_len] — note it is forwarded as
                ``memory_key_padding_mask``, not as an attention mask.

        Returns:
            Decoded tensor [batch_size, tgt_len, d_model].
        """
        return self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_mask)
class TemporalFusionTransformer(nn.Module):
    """
    Temporal Fusion Transformer (simplified).

    Combines three input streams: observed time-series features, static
    features (broadcast over the sequence and fused into the encoder
    input), and known-future features (used as the decoder input).
    """
    def __init__(self, num_observed_features, num_static_features, num_known_future_features,
                 num_output_features=None, hidden_size=128, num_heads=4, num_encoder_layers=3, num_decoder_layers=3,
                 dim_feedforward=512, dropout=0.1):
        """
        Initialize the TFT.

        Args:
            num_observed_features: observed time-series feature dim
                (e.g. the output of a Phased LSTM).
            num_static_features: static feature dim.
            num_known_future_features: known-future feature dim.
            num_output_features: prediction dim; defaults to
                ``num_observed_features`` when None.
            hidden_size: shared embedding / model dim.
            num_heads: number of attention heads.
            num_encoder_layers: encoder depth.
            num_decoder_layers: decoder depth.
            dim_feedforward: feed-forward dim inside each transformer layer.
            dropout: dropout rate.
        """
        super(TemporalFusionTransformer, self).__init__()
        self.hidden_size = hidden_size
        self.num_observed_features = num_observed_features
        self.num_static_features = num_static_features
        self.num_known_future_features = num_known_future_features
        # Default: predict the same feature set we observe.
        self.num_output_features = num_output_features or num_observed_features
        # Per-stream linear embeddings into the shared hidden space.
        self.observed_embedding = nn.Linear(num_observed_features, hidden_size)
        self.static_embedding = nn.Linear(num_static_features, hidden_size)
        self.known_future_embedding = nn.Linear(num_known_future_features, hidden_size)
        # Positional encoding (shared by encoder and decoder inputs).
        self.pos_encoder = PositionalEncoding(hidden_size, dropout)
        # Encoder over the fused observed + static sequence.
        self.encoder = TFTEncoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        # Decoder attending to the encoder memory.
        self.decoder = TFTDecoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        # Projects decoder states back to the output feature space.
        self.output_layer = nn.Linear(hidden_size, self.num_output_features)
        # Fuses the broadcast static embedding with the observed embedding.
        self.static_fusion = nn.Linear(hidden_size * 2, hidden_size)

    def _build_padding_mask(self, mask):
        """
        Normalize a user-supplied mask into a key-padding mask.

        Args:
            mask: None, a 3D validity mask [batch, seq, features]
                (1 = valid, 0 = missing), or a 2D padding mask
                [batch, seq] (True = padded).

        Returns:
            A bool tensor [batch, seq] (True = padded), or None when no
            masking should be applied. Never mutates the caller's tensor.
        """
        if mask is None:
            return None
        if mask.dim() == 3:
            # A timestep is padded only when ALL of its features are missing.
            padded = (mask.sum(dim=-1) == 0)
        elif mask.dim() == 2:
            padded = mask if mask.dtype == torch.bool else mask.bool()
        else:
            # Unsupported rank: behave as if no mask was given.
            return None
        if padded.all():
            # The entire batch is padded — drop the mask so the
            # transformer never sees an all-True padding mask (NaNs).
            return None
        if padded.all(dim=1).any():
            # Some sample is fully padded: un-pad its first timestep so
            # attention always has at least one valid key.
            # Bug fix: clone first — the original mutated an
            # already-bool 2D mask in place, corrupting the caller's tensor.
            padded = padded.clone()
            padded[padded.all(dim=1), 0] = False
        return padded

    def forward(self, observed_features, static_features, known_future_features, mask=None, debug=False):
        """
        Forward pass.

        Args:
            observed_features: [batch_size, seq_len, num_observed_features]
            static_features: [batch_size, num_static_features]
            known_future_features: [batch_size, pred_len, num_known_future_features]
            mask: optional — either a 3D validity mask
                [batch_size, seq_len, num_features] (1 = valid, 0 = missing;
                a timestep is masked when all its features are missing) or a
                2D padding mask [batch_size, seq_len] (True = padded).
            debug: print NaN diagnostics after each stage.

        Returns:
            predictions: [batch_size, pred_len, num_output_features]
        """
        batch_size, seq_len, _ = observed_features.size()
        pred_len = known_future_features.size(1)
        if debug:
            print(f" [TFT] 输入检查:")
            print(f" observed_features: shape={observed_features.shape}, has_nan={torch.isnan(observed_features).any().item()}")
            print(f" static_features: shape={static_features.shape}, has_nan={torch.isnan(static_features).any().item()}")
            print(f" known_future_features: shape={known_future_features.shape}, has_nan={torch.isnan(known_future_features).any().item()}")
        # Embed the observed sequence and add positional information.
        observed_emb = self.observed_embedding(observed_features)  # [B, S, H]
        if debug and torch.isnan(observed_emb).any():
            print(f" ⚠️ observed_emb包含NaN(在embedding后)")
        observed_emb = self.pos_encoder(observed_emb)
        if debug and torch.isnan(observed_emb).any():
            print(f" ⚠️ observed_emb包含NaN(在pos_encoder后)")
        # Embed static features and broadcast them across the sequence.
        static_emb = self.static_embedding(static_features)  # [B, H]
        if debug and torch.isnan(static_emb).any():
            print(f" ⚠️ static_emb包含NaN(在embedding后)")
        static_emb = static_emb.unsqueeze(1).expand(-1, seq_len, -1)  # [B, S, H]
        # Fuse static context with the observed embedding.
        encoder_input = torch.cat([observed_emb, static_emb], dim=-1)  # [B, S, 2H]
        if debug and torch.isnan(encoder_input).any():
            print(f" ⚠️ encoder_input包含NaN(在concat后)")
        encoder_input = self.static_fusion(encoder_input)  # [B, S, H]
        if debug and torch.isnan(encoder_input).any():
            print(f" ⚠️ encoder_input包含NaN(在static_fusion后)")
        # Normalize the mask to a 2D key-padding mask (or None).
        mask_2d = self._build_padding_mask(mask)
        # Encode.
        encoder_output = self.encoder(encoder_input, mask=mask_2d)  # [B, S, H]
        if debug and torch.isnan(encoder_output).any():
            print(f" ⚠️ encoder_output包含NaN(在encoder后)")
        # Embed known-future features as the decoder input.
        known_future_emb = self.known_future_embedding(known_future_features)  # [B, P, H]
        if debug and torch.isnan(known_future_emb).any():
            print(f" ⚠️ known_future_emb包含NaN(在embedding后)")
        known_future_emb = self.pos_encoder(known_future_emb)
        if debug and torch.isnan(known_future_emb).any():
            print(f" ⚠️ known_future_emb包含NaN(在pos_encoder后)")
        # Decode against the encoder memory, reusing the same padding mask.
        decoder_output = self.decoder(
            known_future_emb,
            encoder_output,
            memory_mask=mask_2d
        )  # [B, P, H]
        if debug and torch.isnan(decoder_output).any():
            print(f" ⚠️ decoder_output包含NaN(在decoder后)")
        # Project to the output feature space.
        predictions = self.output_layer(decoder_output)  # [B, P, num_output_features]
        if debug and torch.isnan(predictions).any():
            print(f" ⚠️ predictions包含NaN(在output_layer后)")
        return predictions
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with dropout (sin on even dims, cos on odd)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the [1, max_len, d_model] encoding table once.
        steps = torch.arange(max_len).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(steps * inv_freq)
        table[:, 1::2] = torch.cos(steps * inv_freq)
        # Registered as a buffer so it moves with the module but is not trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the positional encoding for the first x.size(1) steps, then apply dropout."""
        return self.dropout(x + self.pe[:, :x.size(1)])