import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        # 1x1 convolutions project the input into queries, keys, and values
        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        if exists(r):
            # local contexts: position lambdas come from a 3d convolution over local neighborhoods
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            # global contexts: position lambdas come from learned positional embeddings over all n positions
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        # split heads / intra-depth and flatten spatial dims into the sequence dim
        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        # content lambda, shared across all query positions
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            # position lambdas per position via 3d convolution over local neighborhoods
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            # position lambdas per position via learned positional embeddings
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out
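
# Minimal usage sketch (an illustration, not part of the layer itself): a
# LambdaLayer with local contexts (r = 23) applied to a batch of feature maps.
# The shapes and hyperparameters below are assumptions chosen for the example.
if __name__ == '__main__':
    layer = LambdaLayer(
        dim = 32,      # input channels
        dim_out = 32,  # output channels (must be divisible by heads)
        r = 23,        # local receptive field size (must be odd)
        dim_k = 16,    # query / key dimension
        heads = 4,     # number of query heads
        dim_u = 1      # intra-depth dimension
    )
    x = torch.randn(1, 32, 64, 64)
    out = layer(x)
    print(out.shape)   # torch.Size([1, 32, 64, 64])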