import math

import scipy
import torch
import torch.nn.functional as F
from torch import nn, einsum
from functools import partial
from einops import rearrange, reduce
from scipy.fftpack import next_fast_len
def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def identity(t, *args, **kwargs):
    return t


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
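# Usage sketch (illustrative, not part of the original file): `extract` gathers one
# scalar per sample from a 1-D schedule tensor and reshapes it so it broadcasts over
# a batch of sequences, e.g. when applying per-timestep diffusion coefficients:
#   alphas = torch.linspace(0.9, 0.999, 1000)      # hypothetical schedule
#   t = torch.randint(0, 1000, (16,))              # one timestep per sample
#   coef = extract(alphas, t, (16, 24, 6))         # -> shape (16, 1, 1)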
def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
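# Shape sketch (illustrative): both helpers work channel-first on (B, C, L).
#   Upsample(64)(torch.randn(8, 64, 24)).shape     # -> (8, 64, 48)
#   Downsample(64)(torch.randn(8, 64, 24)).shape   # -> (8, 64, 12)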
# normalization functions


def normalize_to_neg_one_to_one(x):
    return x * 2 - 1


def unnormalize_to_zero_to_one(x):
    return (x + 1) * 0.5
# sinusoidal positional embeds


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
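# Usage sketch (illustrative): maps a batch of scalar timesteps (B,) to (B, dim)
# sinusoidal embeddings, as commonly used to condition on the diffusion step.
#   t = torch.randint(0, 1000, (16,))
#   SinusoidalPosEmb(128)(t).shape                 # -> (16, 128)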
# learnable positional embeds


class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Each position gets its own embedding
        # Since indices are always 0 ... max_len, we don't have to do a look-up
        self.pe = nn.Parameter(
            torch.empty(1, max_len, d_model)
        )  # requires_grad automatically set to True
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        """
        x = x + self.pe
        return self.dropout(x)
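# Usage sketch (illustrative): because self.pe has shape (1, max_len, d_model) and is
# added directly, inputs must have sequence length exactly max_len.
#   pe = LearnablePositionalEncoding(d_model=64, max_len=24)
#   pe(torch.randn(8, 24, 64)).shape               # -> (8, 24, 64)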
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # replicate padding on both ends of the time series
        front = x[:, 0:1, :].repeat(
            1, self.kernel_size - 1 - math.floor((self.kernel_size - 1) // 2), 1
        )
        end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2), 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
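# Note (illustrative): the padding adds kernel_size - 1 frames in total (front + end),
# so with stride=1 the pooled output keeps the input length.
#   moving_avg(kernel_size=25, stride=1)(torch.randn(8, 96, 6)).shape  # -> (8, 96, 6)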
class series_decomp(nn.Module):
    """
    Series decomposition block
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean
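# Usage sketch (illustrative): splits a (B, L, C) series into a residual (seasonal)
# part and a smoothed trend part that sum back to the input.
#   res, trend = series_decomp(kernel_size=25)(torch.randn(8, 96, 6))
#   res.shape, trend.shape                         # -> (8, 96, 6), (8, 96, 6)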
class series_decomp_multi(nn.Module):
    """
    Series decomposition block with multiple kernel sizes
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        # nn.ModuleList (rather than a plain list) registers the sub-modules properly
        self.moving_avg = nn.ModuleList(
            [moving_avg(kernel, stride=1) for kernel in kernel_size]
        )
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        moving_mean = []
        for func in self.moving_avg:
            avg = func(x)
            moving_mean.append(avg.unsqueeze(-1))
        moving_mean = torch.cat(moving_mean, dim=-1)
        moving_mean = torch.sum(
            moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1
        )
        res = x - moving_mean
        return res, moving_mean
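# Usage sketch (illustrative): same interface as series_decomp, but the trend is a
# softmax-weighted mixture of moving averages with different kernel sizes.
#   res, trend = series_decomp_multi(kernel_size=[13, 25])(torch.randn(8, 96, 6))
#   res.shape, trend.shape                         # -> (8, 96, 6), (8, 96, 6)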
class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for Sequential module."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)
class Conv_MLP(nn.Module):
    def __init__(self, in_dim, out_dim, resid_pdrop=0.0):
        super().__init__()
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x).transpose(1, 2)
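# Shape sketch (illustrative): takes channel-last (B, L, in_dim), convolves along the
# time axis, and returns channel-last (B, L, out_dim).
#   Conv_MLP(in_dim=6, out_dim=64)(torch.randn(8, 96, 6)).shape  # -> (8, 96, 64)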
class Transformer_MLP(nn.Module):
    def __init__(self, n_embd, mlp_hidden_times, act, resid_pdrop):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv1d(
                in_channels=n_embd,
                out_channels=int(mlp_hidden_times * n_embd),
                kernel_size=1,
                padding=0,
            ),
            act,
            nn.Conv1d(
                in_channels=int(mlp_hidden_times * n_embd),
                out_channels=int(mlp_hidden_times * n_embd),
                kernel_size=3,
                padding=1,
            ),
            act,
            nn.Conv1d(
                in_channels=int(mlp_hidden_times * n_embd),
                out_channels=n_embd,
                kernel_size=3,
                padding=1,
            ),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x)
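# Shape sketch (illustrative): a convolutional feed-forward block that works
# channel-first, expanding to mlp_hidden_times * n_embd and projecting back.
#   mlp = Transformer_MLP(n_embd=64, mlp_hidden_times=4, act=nn.GELU(), resid_pdrop=0.1)
#   mlp(torch.randn(8, 64, 96)).shape              # -> (8, 64, 96)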
class GELU2(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid
        return x * torch.sigmoid(1.702 * x)
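# Note (illustrative): x * sigmoid(1.702 * x) is the sigmoid approximation of GELU,
# so GELU2()(x) closely tracks nn.GELU()(x) over typical activation ranges.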
class AdaLayerNorm(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x
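# Usage sketch (illustrative): normalizes (B, L, n_embd) features and modulates them
# with a scale/shift derived from the diffusion timestep (plus an optional label embedding).
#   ada = AdaLayerNorm(n_embd=64)
#   ada(torch.randn(8, 96, 64), timestep=torch.randint(0, 1000, (8,))).shape  # -> (8, 96, 64)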
class AdaInsNorm(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.instancenorm = nn.InstanceNorm1d(n_embd)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = (
            self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale)
            + shift
        )
        return x
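# Usage sketch (illustrative): same conditioning as AdaLayerNorm, but the features are
# normalized per channel with InstanceNorm1d instead of LayerNorm.
#   ada = AdaInsNorm(n_embd=64)
#   ada(torch.randn(8, 96, 64), timestep=torch.randint(0, 1000, (8,))).shape  # -> (8, 96, 64)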