import math
import torch
import numpy as np
import torch.nn.functional as F

from torch import nn
from einops import rearrange, reduce, repeat

from .model_utils import (
    LearnablePositionalEncoding,
    Conv_MLP,
    AdaLayerNorm,
    Transpose,
    GELU2,
    series_decomp,
)

class TrendBlock(nn.Module):
    """
    Model the trend of a time series with a polynomial regressor.
    """

    def __init__(self, in_dim, out_dim, in_feat, out_feat, act):
        super(TrendBlock, self).__init__()
        trend_poly = 3  # degree of the polynomial basis
        self.trend = nn.Sequential(
            nn.Conv1d(in_channels=in_dim, out_channels=trend_poly, kernel_size=3, padding=1),
            act,
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_feat, out_feat, 3, stride=1, padding=1),
        )

        # Polynomial basis evaluated on a normalized time grid: (trend_poly, out_dim).
        lin_space = torch.arange(1, out_dim + 1, 1) / (out_dim + 1)
        self.poly_space = torch.stack([lin_space ** float(p + 1) for p in range(trend_poly)], dim=0)

    def forward(self, input):
        b, c, h = input.shape
        x = self.trend(input).transpose(1, 2)  # (b, trend_poly, out_feat)
        # Project the learned polynomial coefficients onto the basis.
        trend_vals = torch.matmul(x.transpose(1, 2), self.poly_space.to(x.device))
        trend_vals = trend_vals.transpose(1, 2)  # (b, out_dim, out_feat)
        return trend_vals

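
# A minimal shape sketch (not part of the original module) with hypothetical
# sizes; defined as a function so importing this file does not run it.
def _trend_block_demo():
    block = TrendBlock(in_dim=8, out_dim=24, in_feat=64, out_feat=16, act=nn.GELU())
    x = torch.randn(4, 8, 64)  # (batch, in_dim, in_feat)
    out = block(x)
    assert out.shape == (4, 24, 16)  # (batch, out_dim, out_feat)
    return out
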
class MovingBlock(nn.Module):
    """
    Model the trend of a time series with a moving average.
    """

    def __init__(self, out_dim):
        super(MovingBlock, self).__init__()
        # Moving-average window scales with the series length, clamped to [4, 24].
        size = max(min(int(out_dim / 4), 24), 4)
        self.decomp = series_decomp(size)

    def forward(self, input):
        b, c, h = input.shape
        x, trend_vals = self.decomp(input)
        return x, trend_vals

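
# A minimal sketch (not part of the original module), assuming series_decomp
# follows the Autoformer convention of returning (residual, moving_average);
# that convention is an assumption about model_utils. Sizes are hypothetical.
def _moving_block_demo():
    block = MovingBlock(out_dim=24)  # window size clamps to 6 here
    x = torch.randn(2, 24, 8)  # (batch, time, dim)
    res, trend = block(x)
    return res, trend
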
class FourierLayer(nn.Module):
    """
    Model the seasonality of a time series with the inverse DFT.
    """

    def __init__(self, d_model, low_freq=1, factor=1):
        super().__init__()
        self.d_model = d_model
        self.factor = factor
        self.low_freq = low_freq

    def forward(self, x):
        """x: (b, t, d)"""
        # Note: torch.fft is not supported on the MPS backend; move x to the
        # CPU first (and the result back) when running on Apple silicon.
        b, t, d = x.shape
        x_freq = torch.fft.rfft(x, dim=1)

        if t % 2 == 0:
            # Drop the DC and Nyquist bins.
            x_freq = x_freq[:, self.low_freq:-1]
            f = torch.fft.rfftfreq(t)[self.low_freq:-1]
        else:
            # Drop the DC bin.
            x_freq = x_freq[:, self.low_freq:]
            f = torch.fft.rfftfreq(t)[self.low_freq:]

        x_freq, index_tuple = self.topk_freq(x_freq)
        f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
        f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device)
        return self.extrapolate(x_freq, f, t)

    def extrapolate(self, x_freq, f, t):
        # Mirror the retained coefficients so the reconstruction is real-valued.
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        t = rearrange(torch.arange(t, dtype=torch.float), "t -> () () t ()").to(x_freq.device)

        amp = rearrange(x_freq.abs(), "b f d -> b f () d")
        phase = rearrange(x_freq.angle(), "b f d -> b f () d")
        x_time = amp * torch.cos(2 * math.pi * f * t + phase)
        return reduce(x_time, "b f t d -> b t d", "sum")

    def topk_freq(self, x_freq):
        # Keep the int(factor * log(#bins)) strongest frequencies per series.
        length = x_freq.shape[1]
        top_k = int(self.factor * math.log(length))
        values, indices = torch.topk(x_freq.abs(), top_k, dim=1, largest=True, sorted=True)
        mesh_a, mesh_b = torch.meshgrid(
            torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)), indexing="ij"
        )
        index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
        x_freq = x_freq[index_tuple]
        return x_freq, index_tuple

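
# A minimal sketch (not part of the original module) of the top-k Fourier
# extrapolation: with t=24 there are 11 usable bins, so int(log(11)) = 2
# frequencies are kept per series. Sizes are hypothetical.
def _fourier_layer_demo():
    layer = FourierLayer(d_model=64)
    x = torch.randn(2, 24, 3)  # (batch, time, dim)
    season = layer(x)
    assert season.shape == x.shape  # reconstruction keeps the input shape
    return season
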
class SeasonBlock(nn.Module):
    """
    Model the seasonality of a time series with a Fourier series.
    """

    def __init__(self, in_dim, out_dim, factor=1):
        super(SeasonBlock, self).__init__()
        season_poly = factor * min(32, int(out_dim // 2))
        self.season = nn.Conv1d(in_channels=in_dim, out_channels=season_poly, kernel_size=1, padding=0)

        # Fourier basis (half cosines, half sines) on a normalized time grid.
        fourier_space = torch.arange(0, out_dim, 1) / out_dim
        p1, p2 = (
            (season_poly // 2, season_poly // 2)
            if season_poly % 2 == 0
            else (season_poly // 2, season_poly // 2 + 1)
        )
        s1 = torch.stack([torch.cos(2 * np.pi * p * fourier_space) for p in range(1, p1 + 1)], dim=0)
        s2 = torch.stack([torch.sin(2 * np.pi * p * fourier_space) for p in range(1, p2 + 1)], dim=0)
        self.poly_space = torch.cat([s1, s2])

    def forward(self, input):
        b, c, h = input.shape
        x = self.season(input)  # (b, season_poly, h): one coefficient per basis function
        season_vals = torch.matmul(x.transpose(1, 2), self.poly_space.to(x.device))
        season_vals = season_vals.transpose(1, 2)  # (b, out_dim, h)
        return season_vals

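
# A minimal shape sketch (not part of the original module); out_dim=24 gives
# season_poly = 12 basis functions. Sizes are hypothetical.
def _season_block_demo():
    block = SeasonBlock(in_dim=8, out_dim=24)
    x = torch.randn(2, 8, 64)  # (batch, in_dim, length)
    out = block(x)
    assert out.shape == (2, 24, 64)  # (batch, out_dim, length)
    return out
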
class FullAttention(nn.Module):
    def __init__(
        self,
        n_embd,  # the embedding dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        # Note: the mask argument is accepted but currently unused.
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        att = att.mean(dim=1, keepdim=False)  # (B, T, T), averaged over heads

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att

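
# A minimal sketch (not part of the original module) with hypothetical sizes.
def _full_attention_demo():
    attn = FullAttention(n_embd=64, n_head=4)
    x = torch.randn(2, 10, 64)  # (batch, seq, embed)
    y, att = attn(x)
    assert y.shape == (2, 10, 64) and att.shape == (2, 10, 10)
    return y, att
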
class CrossAttention(nn.Module):
    def __init__(
        self,
        n_embd,  # the embedding dim
        condition_embd,  # condition dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, encoder_output, mask=None):
        # Note: the mask argument is accepted but currently unused.
        B, T, C = x.size()
        B, T_E, _ = encoder_output.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(encoder_output).view(B, T_E, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T_E, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(encoder_output).view(B, T_E, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T_E, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T_E)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T_E) x (B, nh, T_E, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        att = att.mean(dim=1, keepdim=False)  # (B, T, T_E), averaged over heads

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att

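
# A minimal sketch (not part of the original module): queries come from x,
# keys/values from the condition. Sizes are hypothetical.
def _cross_attention_demo():
    attn = CrossAttention(n_embd=64, condition_embd=32, n_head=4)
    x = torch.randn(2, 10, 64)  # queries: (batch, T, n_embd)
    cond = torch.randn(2, 7, 32)  # condition: (batch, T_E, condition_embd)
    y, att = attn(x, cond)
    assert y.shape == (2, 10, 64) and att.shape == (2, 10, 7)
    return y, att
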
class EncoderBlock(nn.Module):
    """an unassuming Transformer block"""

    def __init__(
        self,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        assert activate in ["GELU", "GELU2"]
        act = nn.GELU() if activate == "GELU" else GELU2()

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x, att

class Encoder(nn.Module):
    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.0,
        resid_pdrop=0.0,
        mlp_hidden_times=4,
        block_activate="GELU",
    ):
        super().__init__()
        self.blocks = nn.Sequential(
            *[
                EncoderBlock(
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                )
                for _ in range(n_layer)
            ]
        )

    def forward(self, input, t, padding_masks=None, label_emb=None):
        x = input
        for block_idx in range(len(self.blocks)):
            x, _ = self.blocks[block_idx](x, t, mask=padding_masks, label_emb=label_emb)
        return x

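
# A minimal sketch (not part of the original module). `t` is the per-sample
# diffusion timestep consumed by AdaLayerNorm; the integer range below is a
# hypothetical choice, as are the sizes.
def _encoder_demo():
    enc = Encoder(n_layer=2, n_embd=64, n_head=4)
    x = torch.randn(2, 24, 64)  # (batch, seq, embed)
    t = torch.randint(0, 1000, (2,))
    out = enc(x, t)
    assert out.shape == x.shape
    return out
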
class DecoderBlock(nn.Module):
    """an unassuming Transformer block"""

    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
        condition_dim=1024,
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn1 = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.attn2 = CrossAttention(
            n_embd=n_embd,
            condition_embd=condition_dim,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.ln1_1 = AdaLayerNorm(n_embd)

        assert activate in ["GELU", "GELU2"]
        act = nn.GELU() if activate == "GELU" else GELU2()

        self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act)
        # self.decomp = MovingBlock(n_channel)
        self.seasonal = FourierLayer(d_model=n_embd)
        # self.seasonal = SeasonBlock(n_channel, n_channel)

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

        self.proj = nn.Conv1d(n_channel, n_channel * 2, 1)
        self.linear = nn.Linear(n_embd, n_feat)

    def forward(self, x, encoder_output, timestep, mask=None, label_emb=None):
        # Self-attention, then cross-attention over the encoder output.
        a, att = self.attn1(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        a, att = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask)
        x = x + a

        # Split into two streams and extract the interpretable components.
        x1, x2 = self.proj(x).chunk(2, dim=1)
        trend, season = self.trend(x1), self.seasonal(x2)

        x = x + self.mlp(self.ln2(x))
        m = torch.mean(x, dim=1, keepdim=True)
        return x - m, self.linear(m), trend, season

class Decoder(nn.Module):
    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        n_layer=10,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        condition_dim=512,
    ):
        super().__init__()
        self.d_model = n_embd
        self.n_feat = n_feat
        self.blocks = nn.Sequential(
            *[
                DecoderBlock(
                    n_feat=n_feat,
                    n_channel=n_channel,
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                    condition_dim=condition_dim,
                )
                for _ in range(n_layer)
            ]
        )

    def forward(self, x, t, enc, padding_masks=None, label_emb=None):
        b, c, _ = x.shape
        # Accumulate the seasonal and trend residuals across blocks.
        mean = []
        season = torch.zeros((b, c, self.d_model), device=x.device)
        trend = torch.zeros((b, c, self.n_feat), device=x.device)
        for block_idx in range(len(self.blocks)):
            x, residual_mean, residual_trend, residual_season = self.blocks[block_idx](
                x, enc, t, mask=padding_masks, label_emb=label_emb
            )
            season += residual_season
            trend += residual_trend
            mean.append(residual_mean)

        mean = torch.cat(mean, dim=1)
        return x, mean, trend, season

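
# A minimal sketch (not part of the original module) with hypothetical sizes.
# `enc` stands in for the encoder output; condition_dim must match its last
# dimension, and `t` is assumed to be a per-sample diffusion timestep.
def _decoder_demo():
    dec = Decoder(n_channel=24, n_feat=6, n_embd=64, n_head=4, n_layer=2, condition_dim=64)
    x = torch.randn(2, 24, 64)
    enc = torch.randn(2, 24, 64)
    t = torch.randint(0, 1000, (2,))
    out, mean, trend, season = dec(x, t, enc)
    # out/season keep the embedding width; trend/mean live in feature space.
    assert out.shape == (2, 24, 64) and season.shape == (2, 24, 64)
    assert trend.shape == (2, 24, 6) and mean.shape == (2, 2, 6)
    return out
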
class Transformer(nn.Module):
    def __init__(
        self,
        n_feat,
        n_channel,
        n_layer_enc=5,
        n_layer_dec=14,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        max_len=2048,
        conv_params=None,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop)
        self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop)

        if conv_params is None or conv_params[0] is None:
            if n_feat < 32 and n_channel < 64:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 5, 2
        else:
            kernel_size, padding = conv_params

        # Project the accumulated seasonal component back to feature space.
        self.combine_s = nn.Conv1d(
            n_embd,
            n_feat,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        # Combine the per-layer means into a single offset.
        self.combine_m = nn.Conv1d(
            n_layer_dec,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            padding_mode="circular",
            bias=False,
        )

        self.encoder = Encoder(
            n_layer_enc,
            n_embd,
            n_heads,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
        )
        self.pos_enc = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

        self.decoder = Decoder(
            n_channel,
            n_feat,
            n_embd,
            n_heads,
            n_layer_dec,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
            condition_dim=n_embd,
        )
        self.pos_dec = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

    def forward(self, input, t, padding_masks=None, return_res=False):
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks)

        inp_dec = self.pos_dec(emb)
        output, mean, trend, season = self.decoder(inp_dec, t, enc_cond, padding_masks=padding_masks)

        # Recombine the decomposed components: the decoder residual is split
        # into a mean part (added to the trend) and a zero-mean part (added to
        # the seasonal term).
        res = self.inverse(output)
        res_m = torch.mean(res, dim=1, keepdim=True)
        season_error = self.combine_s(season.transpose(1, 2)).transpose(1, 2) + res - res_m
        trend = self.combine_m(mean) + res_m + trend

        if return_res:
            return (
                trend,
                self.combine_s(season.transpose(1, 2)).transpose(1, 2),
                res - res_m,
            )
        return trend, season_error

if __name__ == "__main__":
    pass
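    # A hedged smoke test (not part of the original file): builds a small
    # model with hypothetical sizes and checks the output shapes. Because of
    # the relative import above, run it as a module, e.g.
    # `python -m <package>.<this_module>`.
    model = Transformer(n_feat=6, n_channel=24, n_layer_enc=1, n_layer_dec=2, n_embd=64, n_heads=4)
    x = torch.randn(2, 24, 6)  # (batch, n_channel, n_feat)
    t = torch.randint(0, 1000, (2,))  # per-sample diffusion timesteps (assumed range)
    trend, season_error = model(x, t)
    assert trend.shape == (2, 24, 6) and season_error.shape == (2, 24, 6)
    print("smoke test passed:", trend.shape, season_error.shape)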