```python
import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
```
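As a quick sanity check (a minimal sketch with illustrative tensor sizes, assuming the definitions above are in scope), `Downsample` halves the spatial resolution and `Upsample` doubles it, while both keep the channel count unchanged:

```python
# Minimal shape check for the resampling helpers (illustrative sizes).
x = torch.randn(1, 32, 64, 64)              # (batch, channels, height, width)
print(Downsample(32)(x).shape)              # torch.Size([1, 32, 32, 32])
print(Upsample(32)(x).shape)                # torch.Size([1, 32, 128, 128])
```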
```python
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
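The module above maps a batch of scalar diffusion timesteps to `dim`-dimensional sinusoidal embeddings, following the positional encodings of the Transformer. A minimal usage sketch with illustrative values, assuming the class above is in scope:

```python
# Embed a batch of integer timesteps into 64-dimensional vectors.
time = torch.tensor([0, 10, 250])           # (batch,)
emb = SinusoidalPositionEmbeddings(64)(time)
print(emb.shape)                            # torch.Size([3, 64])
```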
```python
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
```
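The optional `scale_shift` pair lets a time embedding modulate the normalized activations with a per-channel scale and shift (a FiLM-style conditioning). A minimal sketch with illustrative shapes; the random `scale` and `shift` stand in for tensors that would normally be derived from a time embedding:

```python
# Condition a Block with a per-channel scale and shift.
block = Block(32, 64, groups=8)
x = torch.randn(2, 32, 16, 16)
scale = torch.randn(2, 64, 1, 1)            # broadcast over the spatial dims
shift = torch.randn(2, 64, 1, 1)
print(block(x, scale_shift=(scale, shift)).shape)  # torch.Size([2, 64, 16, 16])
```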
```python
class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)
```
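When a `time_emb_dim` is given, the block projects the time embedding to `dim_out` channels and adds it to the feature map between the two sub-blocks. A minimal sketch with illustrative shapes:

```python
# A ResnetBlock conditioned on a (hypothetical) 128-dimensional time embedding.
res_block = ResnetBlock(32, 64, time_emb_dim=128)
x = torch.randn(2, 32, 16, 16)
t_emb = torch.randn(2, 128)
print(res_block(x, time_emb=t_emb).shape)   # torch.Size([2, 64, 16, 16])
```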
```python
class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
```
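Here the 7×7 depthwise convolution mixes spatial information per channel, and the inner `net` expands to `dim_out * mult` channels before projecting back down, following the inverted bottleneck of ConvNeXt. A minimal sketch with illustrative shapes:

```python
# A ConvNextBlock conditioned on a (hypothetical) 128-dimensional time embedding.
convnext = ConvNextBlock(32, 64, time_emb_dim=128)
x = torch.randn(2, 32, 16, 16)
t_emb = torch.randn(2, 128)
print(convnext(x, time_emb=t_emb).shape)    # torch.Size([2, 64, 16, 16])
```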
```python
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # project to queries, keys, and values with one 1x1 convolution and split
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # subtract the per-row max for numerical stability before the softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
```
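This is standard multi-head self-attention over all spatial positions; the output has the same shape as the input, so the module can be inserted wherever attention is wanted. A minimal shape check with illustrative sizes:

```python
# Self-attention over spatial positions preserves the input shape.
attn = Attention(64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)
print(attn(x).shape)                        # torch.Size([2, 64, 16, 16])
```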
```python
class LinearCrossAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        # queries come from x; keys and values come from the conditioning tensor
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)

        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)

        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)
```
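Despite the name, this module computes full softmax cross-attention: queries come from `x`, while keys and values come from the conditioning feature map `cross_attend`, which must have `dim` channels but may have its own spatial size. A minimal sketch with illustrative shapes:

```python
# Cross-attention: queries from x, keys/values from a conditioning feature map.
cross_attn = LinearCrossAttention(64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)
cond = torch.randn(2, 64, 8, 8)             # a different spatial size is fine
print(cross_attn(x, cond).shape)            # torch.Size([2, 64, 16, 16])
```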
```python
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # normalize the queries over the feature dimension and the keys over the
        # spatial dimension, so keys and values can be aggregated first
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        # scale the queries
        q = q * self.scale

        # aggregate values against keys into a small per-head context matrix
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # read the context back out with the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)

        # reshape back to (batch, channels, height, width)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
```
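By normalizing queries and keys separately and contracting keys with values first, this variant never materializes an `(h*w) × (h*w)` attention matrix, so its cost grows linearly with the number of pixels. A minimal shape check with illustrative sizes:

```python
# Linear attention also preserves the input shape.
lin_attn = LinearAttention(64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)
print(lin_attn(x).shape)                    # torch.Size([2, 64, 16, 16])
```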
```python
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
```
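`PreNorm` applies a group normalization before its wrapped function; combined with `Residual`, it yields a pre-norm attention block of the kind these wrappers are built for. A minimal sketch of that composition, assuming the classes above are in scope (sizes are illustrative):

```python
# Pre-norm linear attention with a residual connection around it.
attn_block = Residual(PreNorm(64, LinearAttention(64)))
x = torch.randn(2, 64, 16, 16)
print(attn_block(x).shape)                  # torch.Size([2, 64, 16, 16])
```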