| from transformers import PreTrainedModel |
| from configuration_LightGTS import LightGTSConfig |
| from ts_generation_mixin import TSGenerationMixin |
| import torch |
| from torch import nn |
| from torch import Tensor |
| from typing import Callable, Optional |
| import math |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
class LightGTSPreTrainedModel(PreTrainedModel):
    """Hugging Face base class wiring LightGTSConfig into the LightGTS family."""

    config_class = LightGTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TSTEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, initializer_range);
        zero biases and any embedding padding row."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
class LightGTSForPrediction(LightGTSPreTrainedModel, TSGenerationMixin):
    """Forecasting wrapper: patchifies the raw series, runs the LightGTS
    backbone, and optionally computes a loss against `labels`."""

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        self.model = LightGTS(c_in=config.c_in,
                            target_dim=config.target_dim,
                            patch_len=config.patch_len,
                            stride=config.stride,
                            num_patch=config.num_patch,
                            e_layers=config.e_layers,
                            d_layers=config.d_layers,
                            n_heads=config.n_heads,
                            d_model=config.d_model,
                            shared_embedding=True,
                            d_ff=config.d_ff,
                            dropout=config.dropout,
                            attn_dropout=config.attn_dropout,
                            head_dropout=config.head_dropout,
                            act='relu',
                            head_type=config.head_type,
                            res_attention=False,
                            learn_pe=False
                            )

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """
        input: tensor [bs x seq_len x n_vars]
        labels: optional target tensor; when given, a loss is computed.
        patch_len / stride / target_dim: optional per-call overrides of the
        patching geometry stored on the config.
        Returns {"prediction": tensor, "loss": tensor or None}.
        """
        # Fix: only override the config when the caller actually supplies a
        # value. Previously a bare forward(input) clobbered config.patch_len /
        # stride / target_dim with None and crashed in the arithmetic below
        # (LightGTSForFinetune already guards the same way).
        if patch_len is not None:
            self.config.patch_len = patch_len
        if stride is not None:
            self.config.stride = stride
        if target_dim is not None:
            self.config.target_dim = target_dim

        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, self.config.patch_len) - self.config.patch_len) // self.config.stride + 1
        self.config.num_patch = num_patch

        # [bs x seq_len x n_vars] -> [bs x num_patch x n_vars x patch_len]
        # NOTE(review): view assumes seq_len == num_patch * patch_len
        # (i.e. stride == patch_len, non-overlapping patches) — confirm callers.
        outputs = input.view(batch_size, num_patch, self.config.patch_len, n_vars)
        outputs = outputs.transpose(2, 3)
        outputs = self.model(outputs, target_dim=self.config.target_dim, patch_len=self.config.patch_len, stride=self.config.stride)

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            # NOTE(review): self.loss_fn is not defined in this file — presumably
            # provided by TSGenerationMixin or set externally; verify.
            loss = self.loss_fn(outputs, labels)

        return {"prediction": outputs, "loss": loss}
| |
class LightGTSForFinetune(LightGTSPreTrainedModel, TSGenerationMixin):
    """Fine-tuning wrapper: patchifies the raw series, runs the LightGTS
    backbone, and optionally computes a loss against `labels`."""

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        self.model = LightGTS(c_in=config.c_in,
                              target_dim=config.target_dim,
                              patch_len=config.patch_len,
                              stride=config.stride,
                              num_patch=config.num_patch,
                              e_layers=config.e_layers,
                              d_layers=config.d_layers,
                              n_heads=config.n_heads,
                              d_model=config.d_model,
                              shared_embedding=True,
                              d_ff=config.d_ff,
                              dropout=config.dropout,
                              attn_dropout=config.attn_dropout,
                              head_dropout=config.head_dropout,
                              act='relu',
                              head_type=config.head_type,
                              res_attention=False,
                              learn_pe=False
                              )

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """
        input: tensor [bs x seq_len x n_vars]
        Returns {"prediction": tensor, "loss": tensor or None}.
        """
        # Per-call overrides of the patching geometry; None keeps the config value.
        for attr, value in (("patch_len", patch_len),
                            ("stride", stride),
                            ("target_dim", target_dim)):
            if value is not None:
                setattr(self.config, attr, value)

        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, self.config.patch_len) - self.config.patch_len) // self.config.stride + 1
        self.config.num_patch = num_patch

        # [bs x seq_len x n_vars] -> [bs x num_patch x n_vars x patch_len]
        # NOTE(review): view assumes seq_len == num_patch * patch_len — confirm callers.
        patches = input.view(batch_size, num_patch, self.config.patch_len, n_vars)
        patches = patches.transpose(2, 3)
        outputs = self.model(patches, target_dim=self.config.target_dim,
                             patch_len=self.config.patch_len, stride=self.config.stride)

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            loss = self.loss_fn(outputs, labels)

        return {"prediction": outputs, "loss": loss}
|
|
|
|
|
|
class LightGTS(nn.Module):
    """
    Patch-based encoder-decoder time-series model with a runtime-variable patch length.

    Output dimension:
         [bs x target_dim x nvars] for prediction
         [bs x target_dim] for regression
         [bs x target_dim] for classification
         [bs x num_patch x n_vars x patch_len] for pretrain
    """
    def __init__(self, c_in:int, target_dim:int, patch_len:int, stride:int, num_patch:int, mask_mode:str = 'patch',mask_nums:int = 3,
                 e_layers:int=3, d_layers:int=3, d_model=128, n_heads=16, shared_embedding=True, d_ff:int=256,
                 norm:str='BatchNorm', attn_dropout:float=0.4, dropout:float=0., act:str="gelu",
                 res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='sincos', learn_pe:bool=False, head_dropout = 0,
                 head_type = "prediction", individual = False,
                 y_range:Optional[tuple]=None, verbose:bool=False, **kwargs):

        super().__init__()
        assert head_type in ['pretrain', 'prediction', 'regression', 'classification'], 'head type should be either pretrain, prediction, or regression'

        self.num_patch = num_patch
        self.target_dim=target_dim
        # Number of decoder patches needed to cover the forecast horizon.
        self.out_patch_num = math.ceil(target_dim / patch_len)
        # Reference patch length the shared embedding is parameterised at; other
        # runtime patch lengths are served by resampling this weight in forward().
        self.target_patch_len = 48

        self.embedding = nn.Linear(self.target_patch_len, d_model)

        # Learnable [CLS]-style token prepended to each variable's patch sequence.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model),requires_grad=True)

        # Channel-independent transformer encoder over patch tokens.
        self.encoder = TSTEncoder(d_model, n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                   pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=e_layers,
                                    store_attn=store_attn)

        # Decoder with causal self-attention and cross-attention to the encoder memory.
        self.decoder = Decoder(d_layers, patch_len=patch_len, d_model=d_model, n_heads=n_heads, d_ff=d_ff,attn_dropout= attn_dropout, dropout=dropout)

        self.n_vars = c_in
        self.head_type = head_type
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.d_model = d_model
        self.patch_len = patch_len

        if head_type == "pretrain":
            self.head = PretrainHead(d_model, patch_len, head_dropout)
        elif head_type == "prediction":
            self.head = decoder_PredictHead(d_model, self.patch_len, self.target_patch_len, head_dropout)

    def get_dynamic_weights(self, n_preds, decay_rate=0.5):
        """
        Generate dynamic weights for the replicated decoder tokens using an exponential decay scheme.

        Args:
        - n_preds (int): Number of predictions to generate weights for.
        - decay_rate (float): The base of the exponential decay. Lower values decay faster (default: 0.5).

        Returns:
        - torch.Tensor: weights [decay_rate**0, decay_rate**1, ..., decay_rate**(n_preds-1)].
        """
        weights = decay_rate ** torch.arange(n_preds)
        return weights

    def decoder_predict(self, bs, n_vars, dec_cross):
        """
        Decode self.out_patch_num output patches against the encoder memory.

        dec_cross: tensor [bs x n_vars x num_patch x d_model]
        Returns: tensor [bs x n_vars x d_model x out_patch_num]

        NOTE(review): bs and n_vars are unused; all shapes come from dec_cross.
        """
        # Seed the decoder with copies of the last encoder token (one per output
        # patch), scaled by exponentially decaying weights.
        dec_in = dec_cross[:,:,-1,:].unsqueeze(2).expand(-1,-1,self.out_patch_num,-1)
        weights = self.get_dynamic_weights(self.out_patch_num).to(dec_in.device)
        dec_in = dec_in * weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

        decoder_output = self.decoder(dec_in, dec_cross)
        # [bs x n_vars x out_patch_num x d_model] -> [bs x n_vars x d_model x out_patch_num]
        decoder_output = decoder_output.transpose(2,3)

        return decoder_output

    def forward(self, z, target_dim=None, patch_len=None, stride=None):
        """
        z: tensor [bs x num_patch x n_vars x patch_len]
        Returns [bs x target_dim x n_vars] (prediction head path).
        """
        # Per-call overrides of horizon / patch geometry (mutates module state).
        if target_dim is not None:
            self.target_dim = target_dim
        if patch_len is not None:
            self.patch_len = patch_len
        if stride is not None:
            self.stride = stride
        self.out_patch_num = math.ceil(self.target_dim / self.patch_len)

        bs, num_patch, n_vars, patch_len = z.shape

        cls_tokens = self.cls_embedding.expand(bs, n_vars, -1, -1)

        # Build a patch embedding for the runtime patch length by resampling the
        # weight trained at target_patch_len.
        # NOTE(review): assumes z's trailing dim equals self.patch_len; a fresh
        # nn.Linear is allocated each call and its random init is overwritten.
        embedding = nn.Linear(patch_len, self.d_model, bias=False)
        embedding.weight.data = resample_patchemb(old=self.embedding.weight.data, new_patch_len=self.patch_len)

        # Embed patches and prepend CLS: [bs x n_vars x (1+num_patch) x d_model].
        z = embedding(z).permute(0,2,1,3)
        z = torch.cat((cls_tokens, z), dim=2)

        # Fold variables into the batch for the channel-independent encoder.
        z = torch.reshape(z, (-1, 1 + num_patch, self.d_model))
        z = self.encoder(z)
        z = torch.reshape(z, (-1, n_vars, 1 + num_patch, self.d_model))

        # Decode out_patch_num patches against the full encoder memory.
        z = self.decoder_predict(bs, n_vars, z[:,:,:,:])

        # Project to patch values and trim to the requested horizon.
        z = self.head(z[:,:,:,:], self.patch_len)
        z = z[:,:self.target_dim, :]

        return z
| |
class TSTEncoder(nn.Module):
    """A stack of TSTEncoderLayer blocks, optionally threading raw attention
    scores from layer to layer (residual attention)."""

    def __init__(self, d_model, n_heads, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()
        self.layers = nn.ModuleList(
            TSTEncoderLayer(d_model, n_heads=n_heads, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
            for _ in range(n_layers)
        )
        self.res_attention = res_attention

    def forward(self, src: Tensor):
        """
        src: tensor [bs x q_len x d_model]
        """
        out = src
        if self.res_attention:
            # Each layer returns raw scores that seed the next layer's attention.
            scores = None
            for layer in self.layers:
                out, scores = layer(out, prev=scores)
        else:
            for layer in self.layers:
                out = layer(out)
        return out
| |
class TSTEncoderLayer(nn.Module):
    """Transformer encoder layer: multi-head self-attention plus a position-wise
    FFN, with configurable pre-/post-norm, BatchNorm-or-LayerNorm, and optional
    residual ("Realformer"-style) attention scores."""

    def __init__(self, d_model, n_heads, d_ff=256, store_attn=False,
                 norm='LayerNorm', attn_dropout=0, dropout=0., bias=True,
                 activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        # Per-head key/value dims (even split of d_model).
        d_k = d_model // n_heads
        d_v = d_model // n_heads

        # Multi-head self-attention; res_attention makes it also return raw scores.
        self.res_attention = res_attention
        self.self_attn = MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)

        # Add & Norm after attention. BatchNorm1d normalises over the feature
        # dim, hence the Transpose sandwich to move d_model into dim 1.
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise feed-forward network.
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))

        # Add & Norm after the FFN.
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src:Tensor, prev:Optional[Tensor]=None):
        """
        src:  tensor [bs x q_len x d_model]
        prev: optional raw attention scores from the previous layer
              (used only when res_attention=True).
        Returns src, plus the updated raw scores when res_attention=True.
        """
        # --- Self-attention sublayer (norm first when pre_norm) ---
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            # Keep the latest attention map for inspection.
            self.attn = attn

        # Residual connection + (post-)norm.
        src = src + self.dropout_attn(src2)
        if not self.pre_norm:
            src = self.norm_attn(src)

        # --- Feed-forward sublayer (norm first when pre_norm) ---
        if self.pre_norm:
            src = self.norm_ffn(src)
        src2 = self.ff(src)
        src = src + self.dropout_ffn(src2)
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        else:
            return src
| |
|
|
class Decoder(nn.Module):
    """Sequential stack of DecoderLayer blocks sharing one encoder memory."""

    def __init__(self, d_layers, patch_len, d_model, n_heads, d_ff=None, attn_dropout=0.2, dropout=0.1):
        super(Decoder, self).__init__()
        self.decoder_layers = nn.ModuleList(
            DecoderLayer(patch_len, d_model, n_heads, d_ff, attn_dropout, dropout)
            for _ in range(d_layers)
        )

    def forward(self, x, cross):
        """x: decoder input; cross: encoder memory fed to every layer."""
        out = x
        for block in self.decoder_layers:
            out = block(out, cross)
        return out
| |
|
|
class DecoderLayer(nn.Module):
    """Decoder block: causal self-attention over output patches, cross-attention
    into the encoder memory (rope_type=True offsets query positions past the
    memory), then a channel MLP. Note each sublayer norm is applied to the
    sublayer output *before* the residual add."""

    def __init__(self, patch_len, d_model, n_heads, d_ff=None, attn_dropout = 0.2, dropout=0.5, norm="BatchNorm"):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiheadAttention(d_model, n_heads, res_attention=False, attn_dropout=attn_dropout)
        # rope_type=True selects RoPE_decoder: queries continue after key positions.
        self.cross_attention = MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, rope_type=True)

        # One norm per sublayer; BatchNorm1d needs d_model in dim 1, hence Transpose.
        if 'batch' in norm.lower():
            self.norm1 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            self.norm2 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            self.norm3 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.MLP1 = CMlp(in_features = d_model, hidden_features = d_ff, out_features = d_model, drop=dropout)

    def forward(self, x, cross):
        """
        x:     tensor [bs x n_vars x num_patch x d_model] (decoder tokens)
        cross: tensor [bs x n_vars x mem_len x d_model]   (encoder memory)
        Returns a tensor shaped like x.
        """
        batch, n_vars, num_patch, d_model = x.shape
        # Fold variables into the batch so attention runs per (sample, variable).
        x = x.reshape(batch*n_vars, num_patch, d_model)

        cross = cross.reshape(batch*n_vars, -1, d_model)

        # Causal self-attention over the output patches.
        attention_mask = causal_attention_mask(num_patch).to(x.device)
        x_attn , _= self.self_attention(x, attn_mask=attention_mask)
        x_attn = self.norm1(x_attn) + x

        # Cross-attention into the encoder memory.
        x_cross , _ = self.cross_attention(x_attn, cross, cross)
        x_cross = self.dropout(self.norm2(x_cross)) + x_attn

        # Channel MLP sublayer.
        x_ff = self.MLP1(x_cross)
        x_ff = self.norm3(x_ff) + x_cross

        # Restore the [bs x n_vars x num_patch x d_model] layout.
        x_ff = x_ff.reshape(batch, n_vars, num_patch, d_model)

        return x_ff
| |
def causal_attention_mask(seq_length):
    """Build an additive causal attention mask of shape (seq_length, seq_length).

    Entry (i, j) is 0.0 when position i may attend to position j (j <= i) and
    -inf when j is in the future (j > i); the mask is meant to be added to raw
    attention scores before softmax.

    Args:
        seq_length (int): sequence length.

    Returns:
        torch.Tensor: additive mask of shape (seq_length, seq_length).
    """
    future = torch.ones(seq_length, seq_length, dtype=torch.bool).triu(diagonal=1)
    mask = torch.zeros(seq_length, seq_length)
    return mask.masked_fill(future, float('-inf'))
|
|
class CMlp(nn.Module):
    """Channel MLP built from 1x1 Conv1d layers; operates on the last dim of a
    [bs x seq x features] tensor by transposing into Conv1d's channel layout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv1d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv1d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """x: [bs x seq x in_features] -> [bs x seq x out_features]"""
        # For a 3-D tensor, permute(0, 2, 1) == transpose(1, 2).
        hidden = x.transpose(1, 2)
        hidden = self.drop(self.act(self.fc1(hidden)))
        hidden = self.drop(self.fc2(hidden))
        return hidden.transpose(1, 2)
| |
class Transpose(nn.Module):
    """Module wrapper around Tensor.transpose, usable inside nn.Sequential."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out
|
|
|
|
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False, rope_type=False):
        """Multi Head Attention Layer
        Input shape:
            Q:       [batch_size (bs) x max_q_len x d_model]
            K, V:    [batch_size (bs) x q_len x d_model]
            mask:    [q_len x q_len]

        rope_type=True selects the decoder variant of rotary embeddings
        (query positions continue after the key positions); lsa makes the
        attention scale a learnable parameter.
        """
        super().__init__()
        # Default per-head dims: even split of d_model across heads.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Q/K/V projections (all heads fused into one matmul each).
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled dot-product attention core (optionally returns raw scores).
        self.res_attention = res_attention
        self.sdp_attn = ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa, rope_type=rope_type)

        # Output projection back to d_model plus projection dropout.
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))

    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        """Self-attention when K/V are omitted; otherwise cross-attention.
        Returns (output, attn_weights) — plus raw attn_scores when res_attention."""

        bs = Q.size(0)
        if K is None: K = Q
        if V is None: V = Q

        # Project and split into heads:
        #   q_s: [bs x n_heads x q_len x d_k]
        #   k_s: [bs x n_heads x d_k x k_len]  (pre-transposed for the matmul)
        #   v_s: [bs x n_heads x k_len x d_v]
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)

        # Scaled dot-product attention (with optional residual scores chaining).
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

        # Merge heads and project: [bs x q_len x d_model].
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        output = self.to_out(output)

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights
|
|
class ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self attention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021). Rotary position embeddings are always applied to q/k
    before the dot product (RoPE_decoder when rope_type=True)."""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False, rope_type=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        # 1/sqrt(head_dim); trainable only for locality self-attention (lsa=True).
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa
        self.rope_type = rope_type

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output:  [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''
        # Apply rotary position embeddings. k arrives pre-transposed
        # ([.. x d_k x seq_len]) so it is flipped for RoPE and flipped back.
        if self.rope_type:
            q, k = RoPE_decoder(q, k.permute(0,1,3,2))
        else:
            q, k = RoPE(q, k.permute(0,1,3,2))
        k = k.permute(0,1,3,2)

        # Scaled dot product: [bs x n_heads x q_len x seq_len].
        attn_scores = torch.matmul(q, k) * self.scale

        # Residual attention: add raw scores carried over from the previous layer.
        if prev is not None: attn_scores = attn_scores + prev

        # Additive (float) or boolean attention mask.
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Mask out padded key positions per batch element.
        if key_padding_mask is not None:
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # Normalise to attention weights and apply attention dropout.
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values.
        output = torch.matmul(attn_weights, v)

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights
|
|
def RoPE(q, k):
    """Apply rotary position embeddings to q and k over positions 0..L-1.

    q, k: [bs x n_heads x seq_len x head_dim] (the position table is sized from
    q, so q and k are expected to share the same sequence length here).
    Returns the rotated (q, k) with unchanged shapes.
    """
    bs, n_heads, seq_len, head_dim = q.shape[0], q.shape[1], q.shape[2], q.shape[-1]

    pos_emb = sinusoidal_position_embedding(bs, n_heads, seq_len, head_dim, q.device, factor=1)

    # Duplicate each cos/sin entry so it lines up with the (even, odd) feature pairs.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate(t):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        rotated = torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1)
        return rotated.reshape(t.shape)

    q = q * cos_pos + _rotate(q) * sin_pos
    k = k * cos_pos + _rotate(k) * sin_pos
    return q, k
|
|
|
|
def RoPE_decoder(q, k):
    """Rotary position embedding for cross-attention: keys take positions
    [0, k_len) and queries take positions [k_len, k_len + q_len), so decoder
    queries are rotated as a continuation of the encoder memory.

    q: [bs x n_heads x q_len x head_dim], k: [bs x n_heads x k_len x head_dim]
    Returns the rotated (q, k) with unchanged shapes.
    """
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    q_max_len = q.shape[2]
    k_max_len = k.shape[2]
    output_dim = q.shape[-1]

    # One position table covering both the memory and the query positions.
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, k_max_len + q_max_len, output_dim, q.device, factor=1)

    # Duplicate cos/sin entries to align with the (even, odd) feature pairs.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # (x0, x1, ...) -> (-x1, x0, ...): the 90-degree-rotated counterpart of q.
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)

    # Queries use the *last* q_max_len rows of the table.
    q = q * cos_pos[:,:,-q_max_len:,:] + q2 * sin_pos[:,:,-q_max_len:,:]

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)

    # Keys use the *first* k_max_len rows.
    k = k * cos_pos[:,:,:k_max_len,:] + k2 * sin_pos[:,:,:k_max_len,:]
    return q, k
|
|
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device, factor=1.0):
    """Sin/cos position table of shape [batch_size x nums_head x max_len x output_dim].

    The last dim is interleaved: [sin(p*t0), cos(p*t0), sin(p*t1), cos(p*t1), ...]
    with t_i = 10000^(-2i/output_dim). factor > 1 oversamples positions and then
    subsamples back to max_len rows.
    """
    # Position indices; for factor == 1 this is simply 0, 1, ..., max_len-1.
    positions = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)
    # Inverse frequencies for each feature pair.
    half_ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * half_ids / output_dim)

    angles = positions * inv_freq                                        # [L x d/2]
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)  # [L x d/2 x 2]

    # Broadcast to every batch element and head, then flatten the (d/2, 2) pairs.
    table = table.repeat((batch_size, nums_head, *([1] * table.dim())))
    table = torch.reshape(table, (batch_size, nums_head, -1, output_dim))
    table = table.to(device)

    # When oversampled, pick max_len evenly spaced rows back out.
    if factor > 1.0:
        keep = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, keep, :]

    return table
|
|
class PretrainHead(nn.Module):
    """Reconstruction head for pretraining: projects d_model features back to
    per-patch values."""

    def __init__(self, d_model, patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, patch_len)

    def forward(self, x):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        output: tensor [bs x num_patch x nvars x patch_len]
        """
        features = x.transpose(2, 3)                   # [bs x nvars x num_patch x d_model]
        patches = self.linear(self.dropout(features))  # [bs x nvars x num_patch x patch_len]
        return patches.permute(0, 2, 1, 3)             # [bs x num_patch x nvars x patch_len]
|
|
|
|
class decoder_PredictHead(nn.Module):
    """Prediction head: projects decoder features to patch values, resampling
    the trained (d_model -> target_patch_len) projection to the runtime
    patch length."""

    def __init__(self, d_model, patch_len, target_patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, target_patch_len)
        self.d_model = d_model

    def forward(self, x, patch_len):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        patch_len: runtime patch length to project to.
        output: tensor [bs x (num_patch * patch_len) x nvars]
        """
        # Resample the trained projection to the requested patch length and
        # apply it directly with F.linear. The previous code allocated a fresh
        # nn.Linear module on every forward (randomly initialising a weight it
        # immediately overwrote); this avoids that per-call allocation while
        # computing the identical no-bias projection.
        weight = resample_patchemb(old=self.linear.weight.data.T, new_patch_len=patch_len).T

        x = x.transpose(2, 3)                  # [bs x nvars x num_patch x d_model]
        x = F.linear(self.dropout(x), weight)  # no bias, as in the original bias=False Linear
        x = x.permute(0, 2, 3, 1)              # [bs x num_patch x patch_len x nvars]
        return x.reshape(x.shape[0], -1, x.shape[3])
| |
def resample_patchemb(old: torch.Tensor, new_patch_len: int):
    """Resample a patch-embedding weight matrix to a new patch length.

    old: 2-D weight of shape (d_model, old_patch_len).
    Returns a tensor of shape (d_model, new_patch_len); the input is returned
    unchanged when the patch length already matches. The mapping is the
    pseudo-inverse of a linear-interpolation resize, scaled by
    sqrt(new/old) to roughly preserve output magnitude.
    """
    assert old.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    kernel = old.T                       # (old_patch_len, d_model)
    src_len = kernel.size(0)
    scale = new_patch_len / src_len

    def _interp(mat, length):
        # Linear interpolation along the last dim of a (C, L) matrix.
        return F.interpolate(mat.unsqueeze(0), size=length, mode='linear').squeeze(0)

    # How linear resize maps each source position into the target grid.
    eye = torch.eye(src_len, dtype=torch.float32, device=kernel.device)
    resize_mat = _interp(eye, new_patch_len).T

    pinv = torch.linalg.pinv(resize_mat.T)

    # Apply the pseudo-inverse mapping, with magnitude compensation.
    resampled = pinv @ kernel * math.sqrt(scale)

    return resampled.T
|
|
|
|
def get_activation_fn(activation):
    """Map an activation spec to a module instance.

    Accepts a callable (instantiated and returned) or a case-insensitive name
    ("relu" / "gelu"); raises ValueError for anything else.
    """
    if callable(activation):
        return activation()
    name = activation.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')