| import torch |
| import math |
| import torch.nn.functional as F |
| from torch import nn, einsum |
| from inspect import isfunction |
|
|
|
|
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
def uniq(arr):
    """Return a dict-keys view of the distinct elements of ``arr``.

    Elements must be hashable. The view iterates in first-occurrence order
    because it is backed by an insertion-ordered dict.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback is needed and ``d`` is a function (as judged by
    ``inspect.isfunction``), it is called with no arguments and its result is
    returned; any other ``d`` is returned as-is.
    """
    if not exists(val):
        # Lazy fallback: only invoke the factory when it is actually needed.
        return d() if isfunction(d) else d
    return val
|
|
|
|
def max_neg_value(t):
    """Return the negated largest finite value representable in ``t``'s dtype.

    ``t.dtype`` is passed to ``torch.finfo``, so ``t`` must have a
    floating-point dtype.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
|
|
|
|
def init_(tensor):
    """Fill ``tensor`` in place with uniform values and return it.

    Values are drawn from ``U(-b, b)`` with ``b = 1 / sqrt(d)``, where ``d``
    is the size of the tensor's last dimension.

    The fill runs under ``torch.no_grad()`` so that it also works on leaf
    tensors that require grad (e.g. ``nn.Parameter``); a tracked in-place
    operation on such a tensor raises ``RuntimeError``.

    Args:
        tensor: Tensor to initialise; must have at least one dimension.

    Returns:
        The same ``tensor`` object, modified in place.

    Raises:
        IndexError: if ``tensor`` is zero-dimensional.
        ZeroDivisionError: if the last dimension has size 0.
    """
    dim = tensor.shape[-1]
    bound = 1 / math.sqrt(dim)
    # Initialisation must not be recorded by autograd.
    with torch.no_grad():
        tensor.uniform_(-bound, bound)
    return tensor
|
|
|
|
| |
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    One linear layer produces ``2 * dim_out`` features; the first half is the
    value and the second half, passed through GELU, is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single projection yields both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Project ``x``, split the last dim in two, return ``value * gelu(gate)``."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
class FeedForward(nn.Module):
    """Feed-forward block: input projection, dropout, output projection.

    Args:
        dim: Size of the input feature dimension.
        dim_out: Size of the output feature dimension; ``dim`` when ``None``.
        mult: Hidden width is ``int(dim * mult)``.
        glu: When true the input projection is a ``GEGLU``; otherwise it is
            ``Linear`` followed by ``GELU``.
        dropout: Probability for the dropout layer between the projections.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.net(x)
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all bias-free linear projections of the same
    input. The result is projected back to ``query_dim`` and passed through
    dropout.

    Args:
        query_dim: Feature size of the input (and of the output).
        heads: Number of attention heads.
        dim_head: Feature size per head; scores are scaled by
            ``dim_head ** -0.5``.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Attend over the second dimension of a 3-D input ``x``.

        The projected queries are unpacked as a 3-tuple shape, so ``x`` must
        be three-dimensional with ``query_dim`` as its last dimension. The
        output has the same shape as ``x``.
        """
        heads = self.heads
        queries = self.to_q(x)
        keys = self.to_k(x)
        values = self.to_v(x)

        batch, tokens, width = queries.shape
        head_dim = width // heads

        def fold_heads(t):
            # (batch, tokens, heads * head_dim) -> (batch * heads, tokens, head_dim)
            t = t.view(batch, tokens, heads, head_dim)
            return t.permute(0, 2, 1, 3).reshape(batch * heads, tokens, head_dim)

        queries = fold_heads(queries)
        keys = fold_heads(keys)
        values = fold_heads(values)

        # Pairwise query/key dot products, scaled, then normalised over keys.
        scores = torch.einsum('b i c, b j c -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, then move heads back into the feature dim.
        mixed = torch.einsum('b i j, b j c -> b i c', weights, values)
        mixed = mixed.view(batch, heads, tokens, head_dim)
        mixed = mixed.permute(0, 2, 1, 3).reshape(batch, tokens, heads * head_dim)

        return self.to_out(mixed)
|
|
|
|
|
|
class Resampler(nn.Module):
    """Residual block: self-attention then feed-forward, each with pre-LayerNorm.

    Args:
        query_dim: Feature size of the input; used for the attention, the
            feed-forward block and both LayerNorms.
        n_heads: Number of attention heads.
        d_head: Feature size per attention head.
    """

    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()

        self.attn = SelfAttention(
            query_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        stages = (
            (self.norm1, self.attn),
            (self.norm2, self.ff),
        )
        for norm, sublayer in stages:
            # Each sublayer sees the normalised input and is added back to x.
            x = x + sublayer(norm(x))
        return x