Spaces:
Running on Zero
Running on Zero
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from refnet.util import default | |
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    Scales the input by the reciprocal RMS over its last dimension,
    then applies a learned per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: size of the last (feature) dimension to normalize.
            eps: small constant added to the mean square for stability.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dim; eps guards against all-zero input.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for numeric stability, then cast back
        # to the input dtype before applying the gain.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
def init_(tensor):
    """Fill *tensor* in place with U(-1/sqrt(d), 1/sqrt(d)), where d is its last dim.

    Returns the same tensor so the call can be chained.
    """
    bound = tensor.shape[-1] ** -0.5
    tensor.uniform_(-bound, bound)
    return tensor
| # feedforward | |
class GEGLU(nn.Module):
    """GELU-gated linear unit: a single projection split into value and gate."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One matmul produces both halves; cheaper than two separate Linears.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return F.gelu(gate) * value
class FeedForward(nn.Module):
    """Transformer MLP block: expand, nonlinearity, dropout, project back down.

    When ``glu`` is True the expansion uses a GEGLU gate instead of
    Linear + GELU.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)  # fall back to the input width
        if glu:
            expand = GEGLU(dim, hidden)
        else:
            expand = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU(),
            )
        self.net = nn.Sequential(
            expand,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """Set every parameter of *module* to zero in place and return it."""
    for param in module.parameters():
        # detach() so the in-place zeroing is invisible to autograd.
        param.detach().zero_()
    return module
def Normalize(in_channels):
    """Build the affine 32-group GroupNorm used throughout this model."""
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
class Upsample(nn.Module):
    """2x nearest-neighbor upsampling with an optional 3x3 convolution.

    This implementation is 2D only (NCHW inputs).

    :param channels: number of channels expected in the input.
    :param use_conv: a bool determining if a convolution is applied
        after upsampling.
    :param out_channels: channels produced by the convolution; defaults
        to ``channels``. Ignored when ``use_conv`` is False.
    :param padding: padding passed to the 3x3 convolution.
    """

    def __init__(self, channels, use_conv, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        # self.conv only exists when use_conv is True; forward() guards on it.
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        # Channel count must match what the (optional) conv was built for.
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x