# ViT_Autoencoder / attention.py
# Uploaded by detectivejoewest ("Upload 7 files", commit 582b238, verified).
from RoPE import apply_angles_2d, generate_angles_2d
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class Attention(nn.Module):
    """Multi-head self-attention over a flattened ``H x W`` token grid.

    Queries, keys and values come from one fused, bias-free linear layer.
    Each of them is reshaped to a per-head 2D grid, passed through
    ``apply_angles_2d`` (imported from the ``RoPE`` module) together with the
    precomputed ``freq`` buffer, flattened again and fed to
    ``F.scaled_dot_product_attention``. A final linear layer projects the
    concatenated heads back to ``emb_dim``.

    Args:
        H: Grid height; the token axis is unflattened as ``(H W)``.
        W: Grid width.
        emb_dim: Embedding width of the input and output tokens.
        n_heads: Number of attention heads; must be positive and divide
            ``emb_dim`` evenly.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``emb_dim``.
    """

    def __init__(self, H: int, W: int, emb_dim: int, n_heads: int = 8):
        super().__init__()
        # Fail early with a clear message: a silently truncated head_dim would
        # otherwise only surface later as a reshape error inside forward().
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.H = H
        self.W = W
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads

        # One matmul yields q, k and v stacked along the last axis.
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        # Kept as an instance attribute for backward compatibility; forward()
        # calls the module-level function directly.
        self.apply_angles_2d = apply_angles_2d
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Non-persistent: the buffer follows .to()/.cuda() with the module but
        # is not written to the state dict.
        self.register_buffer(
            "freq", generate_angles_2d(H, W, head_dim), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[B, H*W, emb_dim]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(
                f"Expected shape [B, N, D], but got shape {tuple(x.shape)}"
            )

        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # to 2D: split heads and unflatten the token axis into the grid.
        to_grid = "B (H W) (h D) -> B h H W D"
        q = rearrange(q, to_grid, H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, to_grid, H=self.H, W=self.W, h=self.n_heads)
        v = rearrange(v, to_grid, H=self.H, W=self.W, h=self.n_heads)

        q = apply_angles_2d(q, self.freq)
        k = apply_angles_2d(k, self.freq)
        # NOTE(review): the values go through apply_angles_2d as well. Rotary
        # position embeddings are normally applied to q and k only. Behaviour
        # is kept unchanged here because altering it would change this
        # module's outputs -- confirm whether this is intentional.
        v = apply_angles_2d(v, self.freq)

        # to 1D: flatten the grid back into a token axis for attention.
        to_seq = "B h H W D -> B h (H W) D"
        q = rearrange(q, to_seq, H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, to_seq, H=self.H, W=self.W, h=self.n_heads)
        v = rearrange(v, to_seq, H=self.H, W=self.W, h=self.n_heads)

        x = F.scaled_dot_product_attention(q, k, v)
        # Merge the heads back into the embedding axis.
        x = rearrange(x, "B h N D -> B N (h D)")
        return self.proj(x)
class ViTBlock(nn.Module):
    """Pre-norm transformer block with residual attention and MLP branches.

    The attention branch is ``LayerNorm -> Attention``. The MLP branch is
    ``LayerNorm -> Linear(emb_dim, 4*emb_dim) -> GELU -> Dropout ->
    Linear(4*emb_dim, emb_dim) -> Dropout``. Each branch output is added back
    onto its input.

    Args:
        H: Grid height of the token layout.
        W: Grid width of the token layout.
        emb_dim: Embedding width of the tokens.
        n_heads: Number of attention heads, forwarded to ``Attention``.
        dropout: Dropout probability used inside the MLP branch.
    """

    def __init__(self, H: int, W: int, emb_dim: int, n_heads: int = 8,
                 dropout: float = 0.1):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim

        self.attn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            Attention(H, W, emb_dim, n_heads=n_heads),
        )
        self.MLP = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, emb_dim * 4, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(emb_dim * 4, emb_dim, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x``.

        Args:
            x: Tensor of shape ``[B, H*W, emb_dim]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its token/embedding
                sizes do not match ``H * W`` and ``emb_dim``.
        """
        # Explicit raises instead of assert so the checks survive `python -O`.
        if x.ndim != 3:
            raise ValueError(
                f"Expected shape [B, N, D], but got shape {x.shape}. You probably passed [B, H, W, D] instead."
            )
        expected = torch.Size([x.shape[0], self.H * self.W, self.emb_dim])
        if x.shape != expected:
            raise ValueError(
                f"Expected shape [B, N, D] -> {expected}, got {x.shape}"
            )

        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
# Sanity check :) -- runs only when this file is executed as a script, so that
# importing the module does not build a model and run a forward pass.
if __name__ == "__main__":
    print(ViTBlock(64, 64, 384)(torch.randn(1, 64**2, 384)).shape)