trm-vit / backbone.py
detectivejoewest's picture
Upload 4 files
ce78e68 verified
import torch
from einops import rearrange
import torch.nn as nn
from utils import apply_angles_1d, generate_angles_1d, RMSNorm
import torch.nn.functional as F
from utils import build_padding_mask
class TransformerBackbone(nn.Module):
    """Stack of pre-norm residual transformer blocks.

    Each block holds two LayerNorms (``ln1``, ``ln2``), an ``Attention``
    module (``attn``) and an ``MLPGeGLU`` module (``mlp``).

    Args:
        config: Mapping read with the keys ``'depth'`` (number of blocks),
            ``'dim'`` (embedding width), ``'context'`` (context length handed
            to ``Attention``) and ``'n_heads'``.
    """

    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList(
            [self._make_block(config) for _ in range(config['depth'])]
        )

    @staticmethod
    def _make_block(config):
        """Build one block as a bare ``nn.Module`` carrying its sub-layers."""
        layer = nn.Module()
        # Registration order (ln1, ln2, attn, mlp) is kept so that state-dict
        # ordering and RNG-dependent initialisation stay unchanged.
        layer.add_module("ln1", nn.LayerNorm(config['dim']))
        layer.add_module("ln2", nn.LayerNorm(config['dim']))
        layer.add_module(
            "attn",
            Attention(config['context'], config['dim'], n_heads=config['n_heads']),
        )
        layer.add_module("mlp", MLPGeGLU(config['dim']))
        return layer

    def forward(self, x, L):
        """Run ``x`` through every block in order.

        Args:
            x: Input tensor; each block adds its attention and MLP outputs
                to it, so the shape is preserved by the residual sums.
            L: Forwarded unchanged to each block's attention module
                (presumably per-sequence lengths for the padding mask;
                confirm against ``build_padding_mask``).

        Returns:
            The tensor after the last block (``x`` itself when there are no
            blocks).
        """
        for layer in self.blocks:
            attended = layer.attn(layer.ln1(x), L)
            x = x + attended
            transformed = layer.mlp(layer.ln2(x))
            x = x + transformed
        return x
class Attention(nn.Module):
    """Multi-head self-attention with a padding mask and optional causality.

    Queries and keys are passed through ``apply_angles_1d`` together with a
    table precomputed by ``generate_angles_1d`` (presumably a rotary-style
    position encoding; confirm in ``utils``).

    Args:
        context_length: Length used to precompute the angle table and passed
            to ``build_padding_mask``.
        emb_dim: Embedding width; must be divisible by ``n_heads``.
        causal: When true, position ``i`` may only attend to positions
            ``<= i``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, context_length, emb_dim, causal=True, n_heads=8):
        super().__init__()
        # Fail early: a non-divisible width would otherwise only surface as a
        # reshape error inside forward().
        if emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.causal = causal
        self.context_length = context_length
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Non-persistent: the table is recomputed on construction and is not
        # written to the state dict.
        self.register_buffer(
            "freq", generate_angles_1d(context_length, head_dim), persistent=False
        )

    def forward(self, x, L):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, D)``.
            L: Passed to ``build_padding_mask`` (presumably valid lengths per
                sequence; confirm in ``utils``).

        Returns:
            Tensor of shape ``(B, N, D)``.
        """
        _, n_tokens, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        padding_mask = build_padding_mask(x, L, self.context_length)
        # NOTE(review): for a (B, N) mask this yields (B, 1, N, 1), which
        # broadcasts over heads and over the key axis, i.e. it masks whole
        # query rows. Confirm this is intended rather than a key mask of
        # shape (B, 1, 1, N).
        padding_mask = padding_mask.unsqueeze(-1).unsqueeze(1)

        q = rearrange(q, "B N (h D) -> B h N D", h=self.n_heads)
        k = rearrange(k, "B N (h D) -> B h N D", h=self.n_heads)
        v = rearrange(v, "B N (h D) -> B h N D", h=self.n_heads)
        q = apply_angles_1d(q, self.freq)
        k = apply_angles_1d(k, self.freq)

        # scaled_dot_product_attention raises when both attn_mask and
        # is_causal=True are given, so the causal constraint is folded into
        # the explicit mask instead.
        attn_mask = padding_mask
        if self.causal:
            allowed = torch.ones(
                n_tokens, n_tokens, dtype=torch.bool, device=padding_mask.device
            ).tril()
            if padding_mask.dtype == torch.bool:
                # Boolean masks: True means "may attend".
                attn_mask = padding_mask & allowed
            else:
                # Additive masks: add -inf above the diagonal.
                causal_bias = torch.zeros(
                    n_tokens,
                    n_tokens,
                    dtype=padding_mask.dtype,
                    device=padding_mask.device,
                ).masked_fill(~allowed, float("-inf"))
                attn_mask = padding_mask + causal_bias

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = rearrange(x, "B h N D -> B N (h D)")
        x = self.proj(x)
        return x
class MLPGeGLU(nn.Module):
    """Gated feed-forward layer using GELU on the input projection.

    Computes ``linearOut(gelu(linearIn(x)) * gate(x))`` and passes the result
    to ``RMSNorm``.

    Args:
        dim: Size of the axis the linear layers act on.
        upsample: Hidden width as a multiple of ``dim``.
        transpose: When true, the last two axes are swapped before and after
            the linear layers, so they act along the token axis.
    """

    def __init__(self, dim: int, upsample=2, transpose=False):
        super().__init__()
        hidden = upsample * dim
        self.dim = dim
        self.transpose = transpose
        # Creation order is kept (linearIn, gate, linearOut) so parameter
        # initialisation and state-dict ordering do not change.
        self.linearIn = nn.Linear(dim, hidden, bias=True)
        self.gate = nn.Linear(dim, hidden, bias=True)
        self.linearOut = nn.Linear(hidden, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Apply the gated MLP to ``x`` of shape ``(B, N, D)``.

        With ``transpose`` set, the layers see ``(B, D, N)``, so ``N`` must
        equal ``dim``; the axes are swapped back before normalisation.

        Returns:
            Whatever ``RMSNorm`` returns for the MLP output.
        """
        if self.transpose:
            # Swap tokens and features so the linears act along the token axis.
            x = rearrange(x, "B N D -> B D N")
        activated = F.gelu(self.linearIn(x))
        gated = activated * self.gate(x)
        out = self.linearOut(gated)
        if self.transpose:
            # Swap the last two axes back.
            out = rearrange(out, "B D N -> B N D")
        # NOTE(review): RMSNorm comes from utils and is called directly on the
        # tensor. If it is an nn.Module class rather than a function, this
        # constructs a module instead of normalising; confirm.
        return RMSNorm(out)
class MLPSwiGLU(nn.Module):
    """Gated feed-forward layer using SiLU on the input projection.

    Computes ``linearOut(silu(linearIn(x)) * gate(x))`` and passes the result
    to ``RMSNorm``.

    Args:
        dim: Size of the axis the linear layers act on.
        upsample: Hidden width as a multiple of ``dim``.
        transpose: When true, the last two axes are swapped before and after
            the linear layers, so they act along the token axis.
    """

    def __init__(self, dim: int, upsample=2, transpose=False):
        super().__init__()
        width = upsample * dim
        self.dim = dim
        self.transpose = transpose
        # Creation order is kept (linearIn, gate, linearOut) so parameter
        # initialisation and state-dict ordering do not change.
        self.linearIn = nn.Linear(dim, width, bias=True)
        self.gate = nn.Linear(dim, width, bias=True)
        self.linearOut = nn.Linear(width, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Apply the gated MLP to ``x`` of shape ``(B, N, D)``.

        With ``transpose`` set, the layers see ``(B, D, N)``, so ``N`` must
        equal ``dim``; the axes are swapped back before normalisation.

        Returns:
            Whatever ``RMSNorm`` returns for the MLP output.
        """
        if self.transpose:
            # Swap tokens and features so the linears act along the token axis.
            x = rearrange(x, "B N D -> B D N")
        value = F.silu(self.linearIn(x))
        value = value * self.gate(x)
        result = self.linearOut(value)
        if self.transpose:
            # Swap the last two axes back.
            result = rearrange(result, "B D N -> B N D")
        # NOTE(review): RMSNorm comes from utils and is called directly on the
        # tensor. If it is an nn.Module class rather than a function, this
        # constructs a module instead of normalising; confirm.
        return RMSNorm(result)