Spaces:
Sleeping
Sleeping
| from functools import reduce | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from .kernel import exported_tdp | |
| import torch.nn.functional as F | |
| from functools import partial | |
| from timm.models.layers import trunc_normal_ | |
class TimeDependentParameter(nn.Module):
    """Container for the raw tensors behind one time-conditioned parameter.

    The module owns four learnable tensors of identical ``shape``:

    * ``param0`` / ``param1`` -- both start as independent copies of the
      same tensor produced by ``init_fn``.
    * ``nodecay_weight`` / ``nodecay_bias`` -- both start at zero.

    The module does not combine them itself.  ``seed_time`` (in this file)
    writes the combined tensor into ``curr_weight``; calling the module
    simply hands that cached tensor back.

    :param shape: shape of each underlying tensor (iterable of ints).
    :param init_fn: in-place initialiser called once on an empty tensor of
        ``shape`` (e.g. ``nn.init.ones_``).
    """

    def __init__(self, shape, init_fn):
        super().__init__()
        self.shape = shape

        # Initialise a single template tensor, then give each parameter its
        # own storage so they can diverge during training.
        template = torch.empty(*shape)
        init_fn(template)
        self.param0 = nn.Parameter(template.detach().clone())
        self.param1 = nn.Parameter(template.detach().clone())

        self.nodecay_weight = nn.Parameter(torch.zeros(*shape))
        self.nodecay_bias = nn.Parameter(torch.zeros(*shape))

        # Filled in externally; stays ``None`` until that happens.
        self.curr_weight = None

    def forward(self):
        """Return the cached ``curr_weight`` (``None`` if never seeded)."""
        return self.curr_weight

    def __repr__(self):
        return f"TimeDependentParameter(shape={self.shape})"
def seed_time(model, log_snr):
    """Refresh ``curr_weight`` on every ``TimeDependentParameter`` in ``model``.

    :param model: any ``nn.Module``; all of its sub-modules are scanned.
    :param log_snr: 1-D tensor of log-SNR values.  When every entry is equal
        it is reduced to a single-element tensor before use.
    :raises AssertionError: if ``log_snr`` is not 1-D.

    The conditioning value handed to ``exported_tdp`` is ``log_snr / 4.0``,
    and the "weight" argument is ``nodecay_weight + 1``.
    NOTE(review): ``exported_tdp`` comes from ``.kernel`` and is not visible
    here, so the shape and meaning of its result are assumptions -- confirm
    against that module.
    """
    assert log_snr.dim() == 1

    # Identical conditions across the batch need only one evaluation.
    head = log_snr[0]
    if torch.all(log_snr == head):
        log_snr = head[None]
    time_condition = log_snr / 4.0

    for module in model.modules():
        if not isinstance(module, TimeDependentParameter):
            continue
        module.curr_weight = exported_tdp(
            module.param0,
            module.param1,
            module.nodecay_weight + 1,
            module.nodecay_bias,
            time_condition,
            custom = False,
        )
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with time-dependent affine terms.

    ``weight`` and ``bias`` are ``TimeDependentParameter`` modules; calling
    them yields the currently seeded tensors.  Their leading axis is either
    1 (one affine pair shared by the whole batch) or equal to the batch size
    (one pair per sample).

    :param dim: size of the normalised (last) axis.
    :param num_groups: stored on the instance; not used by the code here.
    :param eps: numerical-stability term passed to ``F.layer_norm``.
    """

    def __init__(self, dim, num_groups = 1, eps = 1e-05):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.num_groups = num_groups
        self.weight = TimeDependentParameter((dim, ), nn.init.ones_)
        self.bias = TimeDependentParameter((dim, ), nn.init.zeros_)

    def _forward(self, x):
        """Normalise ``x`` of shape ``(batch, tokens, dim)``."""
        gain = self.weight()
        shift = self.bias()
        assert gain.shape[0] == shift.shape[0]
        assert x.shape[-1] == self.dim

        normalized_shape = (self.dim, )
        if gain.shape[0] == 1:
            # Shared affine pair: let layer_norm apply it directly.
            return F.layer_norm(x, normalized_shape, weight = gain[0], bias = shift[0], eps = self.eps)

        # Per-sample affine pair: normalise first, then scale and shift each
        # sample with its own row.
        assert x.shape[0] == gain.shape[0]
        normed = F.layer_norm(x, normalized_shape, eps = self.eps)
        return torch.addcmul(shift[:, None, :], gain[:, None, :], normed)

    def forward(self, x):
        """Apply the norm to ``x`` of shape ``(batch, ..., dim)``; shape is preserved."""
        input_shape = x.shape
        n_batch = input_shape[0]
        assert self.dim == input_shape[-1]
        flat = x.reshape(n_batch, -1, self.dim)
        return self._forward(flat).reshape(*input_shape)
class Linear(nn.Module):
    """Linear projection whose weight and bias are time-dependent parameters.

    ``weight`` (and ``bias`` when enabled) are ``TimeDependentParameter``
    modules; calling them yields the currently seeded tensors.
    NOTE(review): the seeded weight is assumed to be ``(T, din, dout)`` and
    the seeded bias ``(T, dout)``, where ``T`` is the number of distinct time
    conditions -- confirm against ``exported_tdp``.

    :param din: size of the input's last axis.
    :param dout: size of the output's last axis.
    :param bias: whether to add a (time-dependent) bias.
    :param weight_init_fn: in-place initialiser for the raw weight tensors.
    """

    def __init__(self, din, dout, bias = True, weight_init_fn = partial(trunc_normal_, std = 0.02)):
        super().__init__()
        self.din = din
        self.dout = dout
        self.weight = TimeDependentParameter((din, dout), weight_init_fn)
        if bias:
            self.bias = TimeDependentParameter((dout, ), nn.init.zeros_)
        else:
            self.bias = None

    def _forward(self, x):
        """Project ``x`` of shape ``(batch, tokens, din)`` to ``(batch, tokens, dout)``."""
        weight = self.weight()
        bias = self.bias() if self.bias is not None else None

        if weight.shape[0] == 1 and x.shape[0] != 1:
            # One weight shared by the whole batch (``seed_time`` collapses
            # identical time conditions to a single entry).  ``bmm`` and
            # ``baddbmm`` do not broadcast the batch axis, so fold the batch
            # into the row axis and use a plain 2-D matmul instead.
            rows = x.reshape(-1, x.shape[-1])
            if bias is not None:
                out = torch.addmm(bias, rows, weight[0])
            else:
                out = torch.matmul(rows, weight[0])
            return out.reshape(*x.shape[:-1], -1)

        # One weight per sample (or batch size 1): batched matmul.
        if bias is not None:
            return torch.baddbmm(bias[:, None, :], x, weight)
        return torch.bmm(x, weight)

    def forward(self, x):
        """Apply the projection to ``x`` of shape ``(batch, ..., din)``.

        All axes except the last are preserved; the last becomes ``dout``.
        """
        original_shape = x.shape
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1, self.din)
        x = self._forward(x)
        x = x.reshape(*(list(original_shape[:-1]) + [self.dout]))
        return x
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization.

        :param d: model size (length of the last axis of the input)
        :param p: partial RMSNorm fraction; values inside [0, 1] estimate the
            RMS from the first ``int(d * p)`` features only, anything outside
            that range (default -1.0) uses all ``d`` features
        :param eps: epsilon added to the RMS before dividing, default 1e-8
        :param bias: whether to add a learnable offset; disabled by default
            because RMSNorm doesn't enforce re-centering invariance
        """
        super().__init__()
        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.register_parameter("scale", nn.Parameter(torch.ones(d)))
        if self.bias:
            self.register_parameter("offset", nn.Parameter(torch.zeros(d)))

    def forward(self, x):
        """Divide ``x`` by its (possibly partial) RMS, then scale and optionally shift."""
        use_all_features = self.p < 0. or self.p > 1.
        if use_all_features:
            d_x = self.d
            norm_x = x.norm(2, dim=-1, keepdim=True)
        else:
            d_x = int(self.d * self.p)
            # Only the leading ``d_x`` features contribute to the statistic.
            leading, _ = torch.split(x, [d_x, self.d - d_x], dim=-1)
            norm_x = leading.norm(2, dim=-1, keepdim=True)

        # RMS = L2 norm / sqrt(number of features it was taken over).
        rms_x = norm_x * d_x ** (-0.5)
        x_normed = x / (rms_x + self.eps)

        out = self.scale * x_normed
        if self.bias:
            return out + self.offset
        return out
class TDRMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization with time-dependent scale/offset.

        :param d: model size (length of the last axis of the input)
        :param p: partial RMSNorm fraction; values inside [0, 1] estimate the
            RMS from the first ``int(d * p)`` features only, anything outside
            that range (default -1.0) uses all ``d`` features
        :param eps: epsilon added to the RMS before dividing, default 1e-8
        :param bias: whether to add an offset term; disabled by default
            because RMSNorm doesn't enforce re-centering invariance
        """
        super().__init__()
        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias
        self.scale = TimeDependentParameter((d, ), nn.init.ones_)
        if self.bias:
            self.offset = TimeDependentParameter((d, ), nn.init.zeros_)

    def forward(self, x):
        """Normalise ``x`` by its RMS and apply the seeded scale (and offset).

        The seeded scale/offset carry a leading axis that is either 1 (shared
        by the batch; any input rank is accepted) or the batch size (one row
        per sample; only 3-D and 4-D inputs are supported).

        :raises NotImplementedError: per-sample scale with an input that is
            neither 3-D nor 4-D.
        """
        use_all_features = self.p < 0. or self.p > 1.
        if use_all_features:
            d_x = self.d
            norm_x = x.norm(2, dim=-1, keepdim=True)
        else:
            d_x = int(self.d * self.p)
            leading, _ = torch.split(x, [d_x, self.d - d_x], dim=-1)
            norm_x = leading.norm(2, dim=-1, keepdim=True)

        rms_x = norm_x * d_x ** (-0.5)
        x_normed = x / (rms_x + self.eps)

        gain = self.scale()
        shift = self.offset() if self.bias else None

        # Shared scale/offset: ordinary broadcasting over the last axis.
        if gain.shape[0] == 1:
            if shift is None:
                return gain[0] * x_normed
            return gain[0] * x_normed + shift[0]

        # Per-sample scale/offset: insert singleton axes between the batch
        # axis and the feature axis so each sample gets its own row.
        keep = slice(None)
        expanders = {
            3: (keep, None, keep),
            4: (keep, None, None, keep),
        }
        index = expanders.get(x_normed.dim())
        if index is None:
            raise NotImplementedError
        if shift is None:
            return gain[index] * x_normed
        return torch.addcmul(shift[index], gain[index], x_normed)
def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` when present) in place.

    :param layer: a module exposing ``weight`` and ``bias`` attributes;
        ``bias`` may be ``None``.
    :return: the same ``layer`` object, so the call can wrap a constructor.
    """
    nn.init.zeros_(layer.weight)
    bias = layer.bias
    if bias is None:
        return layer
    nn.init.zeros_(bias)
    return layer
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2, dim=-1) + eps)``.

    The statistic and the gain are computed in the promoted dtype of ``x``,
    ``scale`` and float32 (so at least float32); the gain is cast back to
    ``x.dtype`` before the final multiply.

    :param x: input tensor; the mean square is taken over its last axis.
    :param scale: gain tensor, broadcastable against ``x``.
    :param eps: added to the mean square before the reciprocal square root.
    :return: the rescaled tensor.
    """
    compute_dtype = torch.promote_types(
        torch.promote_types(x.dtype, scale.dtype), torch.float32
    )
    x_wide = x.to(compute_dtype)
    mean_sq = torch.mean(x_wide**2, dim=-1, keepdim=True)
    gain = scale.to(compute_dtype) * torch.rsqrt(mean_sq + eps)
    return x * gain.to(x.dtype)
class AdaRMSNorm(nn.Module):
    """RMS normalisation whose gain is predicted from a conditioning vector.

    A bias-free linear layer maps ``cond`` to a per-feature gain.  Its weight
    is zero-initialised, so the gain is exactly 1 at construction time.

    :param features: size of the last axis of the normalised input.
    :param cond_features: size of the last axis of the conditioning input.
    :param eps: stability term forwarded to ``rms_norm``.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = zero_init(projection)

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        """Normalise ``x`` with gain ``linear(cond) + 1``.

        The projected gain gets a singleton axis inserted at position 1
        before broadcasting against ``x``.
        """
        gain = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, gain, self.eps)
class QKNorm(nn.Module):
    """RMS normalisation with a learnable, per-head log-scale.

    ``scale`` holds one log-scale per head, initialised to ``log(10)`` and
    capped at ``log(max_scale)``.

    :param n_heads: number of heads (length of ``scale``).
    :param eps: stability term forwarded to ``rms_norm``.
    :param max_scale: upper bound on the scale; stored as its logarithm.
    """

    def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
        super().__init__()
        self.eps = eps
        self.max_scale = math.log(max_scale)
        self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0)))
        self.proj_()

    def extra_repr(self):
        return f"n_heads={self.scale.shape[0]}, eps={self.eps}"

    # The clamp mutates a leaf parameter that requires grad.  Autograd
    # rejects such in-place edits while gradient tracking is on, so the
    # projection must run under no_grad.
    @torch.no_grad()
    def proj_(self):
        """Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped
        to the max value."""
        self.scale.clamp_(max=self.max_scale)

    def forward(self, x):
        """Normalise ``x`` over its last axis and apply the per-head scale.

        The effective gain is ``exp(0.5 * scale) / x.shape[-1] ** 0.25``,
        broadcast as ``(n_heads, 1, 1)``.
        NOTE(review): this presumes ``x`` is laid out as
        ``(..., n_heads, seq, head_dim)`` -- confirm against callers.
        """
        self.proj_()
        scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1]))
        return rms_norm(x, scale[:, None, None], self.eps)