|
|
import torch
|
|
|
from einops import rearrange
|
|
|
import torch.nn as nn
|
|
|
from utils import apply_angles_1d, generate_angles_1d, RMSNorm
|
|
|
import torch.nn.functional as F
|
|
|
from utils import build_padding_mask
|
|
|
|
|
|
class TransformerBackbone(nn.Module):
    """Stack of pre-norm residual blocks (attention followed by a gated MLP).

    Each block computes ``x = x + attn(ln1(x), L)`` and then
    ``x = x + mlp(ln2(x))``.

    Args:
        config: mapping read with the keys ``'depth'`` (number of blocks),
            ``'dim'`` (embedding width), ``'context'`` (context length handed
            to ``Attention``) and ``'n_heads'`` (attention heads).
    """

    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList(
            self._build_block(config) for _ in range(config['depth'])
        )

    @staticmethod
    def _build_block(config):
        """Create one block container holding ln1, ln2, attn and mlp (in that order)."""
        width = config['dim']
        layer = nn.Module()
        # Registration order is kept identical so parameter init order and
        # state_dict key order do not change.
        layer.ln1 = nn.LayerNorm(width)
        layer.ln2 = nn.LayerNorm(width)
        layer.attn = Attention(config['context'], width, n_heads=config['n_heads'])
        layer.mlp = MLPGeGLU(width)
        return layer

    def forward(self, x, L):
        """Run ``x`` through every block; ``L`` is forwarded to each attention layer."""
        for layer in self.blocks:
            attended = layer.attn(layer.ln1(x), L)
            x = x + attended
            x = x + layer.mlp(layer.ln2(x))
        return x
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with per-position angles and a padding mask.

    Args:
        context_length: maximum sequence length; used to precompute the
            ``freq`` angle buffer and passed to ``build_padding_mask``.
        emb_dim: embedding width; must be divisible by ``n_heads``.
        causal: if True, position i may only attend to positions <= i.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, context_length, emb_dim, causal=True, n_heads=8):
        super().__init__()
        # Fail early: the head split in forward() needs an exact division.
        if emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.causal = causal
        self.context_length = context_length
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Angles applied to q and k via apply_angles_1d (presumably a rotary
        # position embedding - confirm in utils). Not persisted in state_dict.
        self.register_buffer(
            "freq", generate_angles_1d(context_length, head_dim), persistent=False
        )

    @staticmethod
    def _with_causal(mask, n_tokens):
        """Combine ``mask`` with an (n_tokens x n_tokens) lower-triangular causal mask.

        Boolean masks (SDPA convention: True = may attend) are AND-ed with the
        causal keep-mask. Floating-point masks are treated as additive biases
        and receive -inf above the diagonal. Broadcasting follows ``mask``.
        """
        keep = torch.ones(
            n_tokens, n_tokens, dtype=torch.bool, device=mask.device
        ).tril()
        if mask.dtype == torch.bool:
            return mask & keep
        causal_bias = torch.zeros(
            n_tokens, n_tokens, dtype=mask.dtype, device=mask.device
        ).masked_fill(~keep, float("-inf"))
        return mask + causal_bias

    def forward(self, x, L):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (B, N, emb_dim).
            L: forwarded unchanged to ``build_padding_mask`` together with
                ``x`` and ``context_length`` (presumably per-sample valid
                lengths - confirm in utils).

        Returns:
            Tensor of shape (B, N, emb_dim).
        """
        _, n_tokens, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # NOTE(review): if build_padding_mask returns (B, N), these unsqueezes
        # give (B, 1, N, 1), which masks *query rows* rather than keys (a key
        # mask would be (B, 1, 1, N)). Fully masked rows can yield NaN outputs
        # at padded positions. Confirm the intended layout.
        padding_mask = build_padding_mask(x, L, self.context_length)
        padding_mask = padding_mask.unsqueeze(-1).unsqueeze(1)

        q = rearrange(q, "B N (h D) -> B h N D", h=self.n_heads)
        k = rearrange(k, "B N (h D) -> B h N D", h=self.n_heads)
        v = rearrange(v, "B N (h D) -> B h N D", h=self.n_heads)

        q = apply_angles_1d(q, self.freq)
        k = apply_angles_1d(k, self.freq)

        # SDPA raises when attn_mask and is_causal=True are both supplied, so
        # the causal constraint is folded into the explicit mask instead. This
        # also makes the `causal` constructor flag take effect.
        if self.causal:
            attn_mask = self._with_causal(padding_mask, n_tokens)
        else:
            attn_mask = padding_mask
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        x = rearrange(x, "B h N D -> B N (h D)")
        return self.proj(x)
|
|
|
|
|
|
class MLPGeGLU(nn.Module):
    """Gated feed-forward layer: ``linearOut(gelu(linearIn(x)) * gate(x))``.

    The hidden width is ``upsample * dim``. The gated result is passed to
    ``RMSNorm`` (imported from ``utils``) before being returned.

    With ``transpose=True`` the last two axes of a (B, N, D) input are swapped
    first, so the linear layers act along the token axis N (which must then
    equal ``dim``). The output is swapped back to (B, N, D).
    """

    def __init__(self, dim: int, upsample: int = 2, transpose: bool = False):
        """
        Args:
            dim: size of the axis the linear layers act on - the feature
                axis D normally, or the token axis N when ``transpose`` is set.
            upsample: hidden-width multiplier (hidden = upsample * dim).
            transpose: if True, mix across tokens instead of features.
        """
        super().__init__()
        self.transpose = transpose
        self.dim = dim
        self.linearIn = nn.Linear(dim, upsample * dim, bias=True)
        self.gate = nn.Linear(dim, upsample * dim, bias=True)
        self.linearOut = nn.Linear(upsample * dim, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Apply the gated MLP.

        Args:
            x: tensor whose last axis has size ``dim``. When ``transpose`` is
                set it must be 3-D (B, N, D) with N == ``dim``.

        Returns:
            The result of ``RMSNorm`` on a tensor with the same shape as ``x``.
        """
        if self.transpose:
            x = rearrange(x, "B N D -> B D N")
        x = self.linearOut(F.gelu(self.linearIn(x)) * self.gate(x))
        if self.transpose:
            x = rearrange(x, "B D N -> B N D")
        # NOTE(review): RMSNorm is called directly on the tensor. If
        # utils.RMSNorm is an nn.Module class rather than a function, this
        # constructs a module instead of normalizing - confirm.
        return RMSNorm(x)
|
|
|
|
|
|
class MLPSwiGLU(nn.Module):
    """Gated feed-forward layer: ``linearOut(silu(linearIn(x)) * gate(x))``.

    The hidden width is ``upsample * dim``. The gated result is passed to
    ``RMSNorm`` (imported from ``utils``) before being returned.

    With ``transpose=True`` the last two axes of a (B, N, D) input are swapped
    first, so the linear layers act along the token axis N (which must then
    equal ``dim``). The output is swapped back to (B, N, D).
    """

    def __init__(self, dim: int, upsample: int = 2, transpose: bool = False):
        """
        Args:
            dim: size of the axis the linear layers act on - the feature
                axis D normally, or the token axis N when ``transpose`` is set.
            upsample: hidden-width multiplier (hidden = upsample * dim).
            transpose: if True, mix across tokens instead of features.
        """
        super().__init__()
        self.transpose = transpose
        self.dim = dim
        self.linearIn = nn.Linear(dim, upsample * dim, bias=True)
        self.gate = nn.Linear(dim, upsample * dim, bias=True)
        self.linearOut = nn.Linear(upsample * dim, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Apply the gated MLP.

        Args:
            x: tensor whose last axis has size ``dim``. When ``transpose`` is
                set it must be 3-D (B, N, D) with N == ``dim``.

        Returns:
            The result of ``RMSNorm`` on a tensor with the same shape as ``x``.
        """
        if self.transpose:
            x = rearrange(x, "B N D -> B D N")
        x = self.linearOut(F.silu(self.linearIn(x)) * self.gate(x))
        if self.transpose:
            x = rearrange(x, "B D N -> B N D")
        # NOTE(review): RMSNorm is called directly on the tensor. If
        # utils.RMSNorm is an nn.Module class rather than a function, this
        # constructs a module instead of normalizing - confirm.
        return RMSNorm(x)