|
|
from RoPE import apply_angles_2d, generate_angles_2d
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from einops import rearrange
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over a flattened ``H x W`` token grid.

    Queries and keys are reshaped back onto the 2D grid and passed through
    ``apply_angles_2d`` together with the precomputed ``freq`` buffer
    (presumably 2D rotary position angles, judging by the ``RoPE`` module
    name -- confirm against that module). Values are left un-rotated.

    Args:
        H: Grid height in tokens.
        W: Grid width in tokens.
        emb_dim: Embedding width of the input tokens; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, H: int, W: int, emb_dim: int, n_heads: int = 8) -> None:
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces as an einops
        # reshape error the first time forward() runs.
        if emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.H = H
        self.W = W
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads

        # One fused, bias-free projection produces q, k and v side by side.
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        # Kept as an attribute for backward compatibility; forward() uses it.
        self.apply_angles_2d = apply_angles_2d
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Non-persistent: follows .to()/.cuda() but is excluded from state_dict.
        self.register_buffer(
            "freq", generate_angles_2d(H, W, head_dim), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[B, H * W, emb_dim]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Split heads and restore the 2D grid so the angles can be applied
        # per (row, column) position.
        to_grid = "B (H W) (h D) -> B h H W D"
        q = rearrange(q, to_grid, H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, to_grid, H=self.H, W=self.W, h=self.n_heads)

        # Positional rotation belongs on queries and keys only: it makes the
        # q.k score depend on relative position. Rotating the values as well
        # (as the previous revision did) bakes absolute position into the
        # attention output instead.
        q = self.apply_angles_2d(q, self.freq)
        k = self.apply_angles_2d(k, self.freq)

        to_seq = "B h H W D -> B h (H W) D"
        q = rearrange(q, to_seq)
        k = rearrange(k, to_seq)
        v = rearrange(v, "B N (h D) -> B h N D", h=self.n_heads)

        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "B h N D -> B N (h D)")
        return self.proj(out)
|
|
|
|
|
|
class ViTBlock(nn.Module):
    """Pre-norm transformer block for an ``H x W`` token grid.

    Two residual branches are applied in sequence: ``LayerNorm -> Attention``
    followed by ``LayerNorm -> Linear(4x) -> GELU -> Dropout -> Linear ->
    Dropout``.

    Args:
        H: Grid height in tokens.
        W: Grid width in tokens.
        emb_dim: Embedding width of each token.
        n_heads: Number of attention heads forwarded to ``Attention``.
        dropout: Dropout probability used inside the MLP branch.
    """

    def __init__(
        self,
        H: int,
        W: int,
        emb_dim: int,
        n_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        # nn.Module must be initialised before anything is assigned on the
        # subclass; doing it first keeps this safe if a tensor, parameter or
        # submodule attribute is ever added above the layers.
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim

        self.attn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            Attention(H, W, emb_dim, n_heads=n_heads),
        )
        self.MLP = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, emb_dim * 4, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(emb_dim * 4, emb_dim, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP residual branches.

        Args:
            x: Tensor of shape ``[B, H * W, emb_dim]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            AssertionError: If ``x`` is not 3-dimensional or its trailing
                dimensions do not match ``(H * W, emb_dim)``.
        """
        assert x.ndim == 3, f"Expected shape [B, N, D], but got shape {x.shape}. You probably passed [B, H, W, D] instead."
        # Build the expected size once; it is used for both check and message.
        expected = torch.Size([x.shape[0], self.H * self.W, self.emb_dim])
        assert x.shape == expected, f"Expected shape [B, N, D] -> {expected}, got {x.shape}"

        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push one random batch through a 64x64-token block and print
    # the output shape. Guarded so that importing this module does not run a
    # 4096-token forward pass as a side effect.
    print(ViTBlock(64, 64, 384)(torch.randn(1, 64**2, 384)).shape)