import torch import torch.nn as nn import torch.nn.functional as F import math class PositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=5000): super(PositionalEmbedding, self).__init__() # Compute the positional encodings once in log space. pe = torch.zeros(max_len, d_model).float() pe.require_grad = False position = torch.arange(0, max_len).float().unsqueeze(1) div_term = ( torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) ).exp() pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer("pe", pe) def forward(self, x): return self.pe[:, : x.size(1)] class TokenEmbedding(nn.Module): def __init__(self, c_in, d_model): super(TokenEmbedding, self).__init__() padding = 1 if torch.__version__ >= "1.5.0" else 2 self.tokenConv = nn.Conv1d( in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode="circular", ) for m in self.modules(): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_( m.weight, mode="fan_in", nonlinearity="leaky_relu" ) def forward(self, x): x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) return x class TokenEmbeddingBasic(nn.Module): def __init__(self, c_in, d_model): super(TokenEmbeddingBasic, self).__init__() self.linear = nn.Linear(c_in, d_model) def forward(self, x): x = self.linear(x) return x class FixedEmbedding(nn.Module): def __init__(self, c_in, d_model): super(FixedEmbedding, self).__init__() w = torch.zeros(c_in, d_model).float() w.require_grad = False position = torch.arange(0, c_in).float().unsqueeze(1) div_term = ( torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) ).exp() w[:, 0::2] = torch.sin(position * div_term) w[:, 1::2] = torch.cos(position * div_term) self.emb = nn.Embedding(c_in, d_model) self.emb.weight = nn.Parameter(w, requires_grad=False) def forward(self, x): return self.emb(x).detach() class TemporalEmbedding(nn.Module): def __init__(self, d_model, t_embed="fixed", freq="h"): super(TemporalEmbedding, self).__init__() minute_size = 4 hour_size = 24 weekday_size = 7 day_size = 32 month_size = 13 Embed = FixedEmbedding if t_embed == "fixed" else nn.Embedding if freq == "t": self.minute_embed = Embed(minute_size, d_model) self.hour_embed = Embed(hour_size, d_model) self.weekday_embed = Embed(weekday_size, d_model) self.day_embed = Embed(day_size, d_model) self.month_embed = Embed(month_size, d_model) def forward(self, x): x = x.long() minute_x = ( self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 ) hour_x = self.hour_embed(x[:, :, 3]) weekday_x = self.weekday_embed(x[:, :, 2]) day_x = self.day_embed(x[:, :, 1]) month_x = self.month_embed(x[:, :, 0]) return hour_x + weekday_x + day_x + month_x + minute_x class TimeFeatureEmbedding(nn.Module): def __init__(self, d_model, t_embed="timeF", freq="h"): super(TimeFeatureEmbedding, self).__init__() freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} d_inp = freq_map[freq] self.embed = nn.Linear(d_inp, d_model) def forward(self, x): return self.embed(x) class Time2Vec(nn.Module): def __init__(self, time_emb_dim, freq="h"): super(Time2Vec, self).__init__() freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} time_feat_dim = freq_map[freq] self.output_dim = time_emb_dim self.out_features = time_emb_dim # TODO: Initialize uniform self.linear_periodic = nn.Linear(time_feat_dim, time_emb_dim - 1) self.linear_non_periodic = nn.Linear(time_feat_dim, 1) def forward(self, x): non_periodic = self.linear_non_periodic(x.float()) periodic = torch.sin(self.linear_periodic(x.float())) out = torch.cat([non_periodic, periodic], -1) return out class DataEmbedding(nn.Module): def __init__( self, c_in, d_model, t_embed="fixed", freq="h", dropout_emb=0.01, position_embedding=True, emb_t2v_app_dim=32, tok_emb="default", ): super(DataEmbedding, self).__init__() self.append_time_emb = t_embed == "time2vec_app" # For the temporal embedding if t_embed is not None: assert t_embed in [ "fixed", "learned", "timeF", "time2vec_add", "time2vec_app", ], "Invalid t_embed" if t_embed == "fixed" or t_embed == "learned": self.temporal_embedding = TemporalEmbedding( d_model=d_model, t_embed=t_embed, freq=freq ) elif t_embed == "timeF": self.temporal_embedding = TimeFeatureEmbedding( d_model=d_model, t_embed=t_embed, freq=freq ) elif t_embed == "time2vec_add": # Time2Vec time embedding add elementwise self.temporal_embedding = Time2Vec(time_emb_dim=d_model, freq=freq) elif t_embed == "time2vec_app": # Time2Vec time embedding appended assert ( emb_t2v_app_dim is not None ), "Need to provide the emb_t2v_app_dim argument" assert emb_t2v_app_dim > 0 and emb_t2v_app_dim < d_model self.temporal_embedding = Time2Vec( time_emb_dim=emb_t2v_app_dim, freq=freq ) d_model -= emb_t2v_app_dim else: self.temporal_embedding = lambda _: 0 # For the value embedding if tok_emb == "basic": self.value_embedding = TokenEmbeddingBasic(c_in=c_in, d_model=d_model) elif tok_emb == "raw": self.value_embedding = lambda x: x assert c_in == d_model, "c_in and d_model must be equal for raw embedding" assert ( t_embed != "time2vec_app" ), "time2vec_app not supported for raw embedding" else: self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) self.position_embedding = ( PositionalEmbedding(d_model=d_model) if position_embedding else lambda x: 0 ) self.dropout = nn.Dropout(p=dropout_emb) def forward(self, x, x_mark): if self.append_time_emb: x = self.value_embedding(x) + self.position_embedding(x) x_drop = self.dropout(x) time_emb = self.temporal_embedding(x_mark) return torch.concat([x_drop, time_emb], -1) else: x = ( self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark) ) return self.dropout(x)