import torch
from torch import nn, einsum
from einops import rearrange, repeat


class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_learns = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, learns):
        """
        Args:
            x (torch.Tensor): media features, shape (b, n, dim)
            learns (torch.Tensor): learnable query tokens, shape (b, num_learns, dim)
        Returns:
            torch.Tensor of shape (b, num_learns, dim)
        """
        x = self.norm_media(x)
        learns = self.norm_learns(learns)

        h = self.heads

        # queries come from the learnable tokens only
        q = self.to_q(learns)

        # keys/values attend over the media features concatenated with the learnable tokens
        kv_input = torch.cat((x, learns), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        q = q * self.scale

        # attention, with the max subtracted for numerical stability
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_learns=3,
        ff_mult=4,
    ):
        super().__init__()
        self.num_learns = num_learns
        self.learns = nn.Parameter(torch.randn(num_learns, dim))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features,
                shape (b, 256, 4096)
        Returns:
            torch.Tensor of shape (b, 3, 4096), where 3 is self.num_learns
        """
        b, n, d = x.shape

        # broadcast the learnable query tokens across the batch
        learns = repeat(self.learns, "n d -> b n d", b=b)

        # alternate cross-attention and feed-forward blocks, each with a residual connection
        for attn, ff in self.layers:
            learns = attn(x, learns) + learns
            learns = ff(learns) + learns

        return self.norm(learns)


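# A minimal usage sketch for the resampler (an assumption based on the docstring
# above: 256 image tokens with dim=4096; the actual pipeline may use other sizes):
#
#     resampler = PerceiverResampler(dim=4096)
#     image_feats = torch.randn(2, 256, 4096)   # (b, n, dim)
#     tokens = resampler(image_feats)           # (2, 3, 4096), 3 = num_learns

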
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_mult=4):
        super().__init__()
        # operates over the token dimension, so input_dim must equal the number of
        # input tokens; the second block compresses them down to 3
        self.ff1 = FeedForward_2(input_dim, input_dim, hidden_mult)
        self.ff2 = FeedForward_2(input_dim, 3, hidden_mult)

    def forward(self, x):
        # (b, n, d) -> (b, d, n) so the feed-forwards mix across tokens
        x = x.permute(0, 2, 1)
        x = self.ff1(x)
        x = self.ff2(x)
        # (b, d, 3) -> (b, 3, d)
        x = x.permute(0, 2, 1)
        return x


class MLP_6763(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_mult=2):
        super().__init__()
        # expects the flattened (n * d) features of each example as input_dim
        self.ff1 = FeedForward_2(input_dim, output_dim, hidden_mult)
        self.ff2 = FeedForward_2(output_dim, output_dim, hidden_mult)

    def forward(self, x):
        b, n, d = x.shape
        # flatten the token and feature dimensions into a single vector per example
        x = x.view(b, -1)
        x = self.ff1(x)
        x = self.ff2(x)
        return x


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)


class FeedForward_2(nn.Module):
    def __init__(self, input_dim, output_dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim * mult),
            nn.GELU(),
            nn.Linear(input_dim * mult, output_dim),
        )

    def forward(self, x):
        return self.net(x)
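

# A quick shape check, kept as a sketch: the toy sizes below are assumptions chosen
# so the flattened MLP_6763 input stays small; the PerceiverResampler docstring above
# suggests (b, 256, 4096) features in practice.
if __name__ == "__main__":
    b, n, dim = 2, 16, 64
    feats = torch.randn(b, n, dim)

    resampler = PerceiverResampler(dim=dim, depth=2)
    print(resampler(feats).shape)      # (2, 3, 64); 3 learnable query tokens

    # MLP mixes across the token dimension, so input_dim must equal n
    mlp = MLP(input_dim=n)
    print(mlp(feats).shape)            # (2, 3, 64)

    # MLP_6763 flattens each example to a vector of length n * dim
    head = MLP_6763(input_dim=n * dim, output_dim=3)
    print(head(feats).shape)           # (2, 3)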