import torch
from einops import rearrange
import torch.nn as nn
from utils import apply_angles_1d, generate_angles_1d, RMSNorm
import torch.nn.functional as F
from utils import build_padding_mask
class TransformerBackbone(nn.Module):
    """Stack of pre-norm transformer blocks: residual attention followed by a
    residual GeGLU MLP, each preceded by its own LayerNorm.

    config keys used: 'depth' (number of blocks), 'dim' (embedding size),
    'context' (max sequence length), 'n_heads' (attention heads).
    """
    def __init__(self, config):
        super().__init__()
        dim = config['dim']
        layers = []
        for _ in range(config['depth']):
            layer = nn.Module()  # plain container module for the sub-layers
            layer.ln1 = nn.LayerNorm(dim)
            layer.ln2 = nn.LayerNorm(dim)
            layer.attn = Attention(config['context'], dim, n_heads=config['n_heads'])
            layer.mlp = MLPGeGLU(dim)
            layers.append(layer)
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, L):
        """Run x (B, N, D) through every block; L is the per-sample valid
        length, forwarded to attention for padding masking."""
        for layer in self.blocks:
            x = x + layer.attn(layer.ln1(x), L)
            x = x + layer.mlp(layer.ln2(x))
        return x
class Attention(nn.Module):
    """Multi-head self-attention with 1-D rotary position embeddings and a
    length-derived padding mask.

    Args:
        context_length: maximum sequence length (sizes the rotary-angle buffer).
        emb_dim: embedding dimension; assumed divisible by n_heads.
        causal: when True, each position may only attend to itself and earlier
            positions.
        n_heads: number of attention heads.
    """
    def __init__(self, context_length, emb_dim, causal=True, n_heads=8):
        super().__init__()
        self.causal = causal
        self.context_length = context_length
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads
        self.qkv = nn.Linear(emb_dim, 3*emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Rotary angles precomputed for the full context; excluded from state_dict.
        self.register_buffer("freq", generate_angles_1d(context_length, head_dim), persistent=False)

    def forward(self, x, L):
        """
        x: (B, N, D) token embeddings.
        L: per-sample valid lengths, forwarded to build_padding_mask.
        Returns: (B, N, D) attention output.
        """
        B, N, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        padding_mask = build_padding_mask(x, L, self.context_length)
        # Insert head and key axes so the mask broadcasts over SDPA's
        # (B, h, N_q, N_k) scores.
        # NOTE(review): .unsqueeze(-1).unsqueeze(1) yields (B, 1, N, 1), which
        # masks the *query* axis; padding is conventionally masked on the key
        # axis (B, 1, 1, N) — confirm build_padding_mask's output convention.
        padding_mask = padding_mask.unsqueeze(-1).unsqueeze(1)
        q = rearrange(q, "B N (h D) -> B h N D", h=self.n_heads)
        k = rearrange(k, "B N (h D) -> B h N D", h=self.n_heads)
        v = rearrange(v, "B N (h D) -> B h N D", h=self.n_heads)
        q = apply_angles_1d(q, self.freq)
        k = apply_angles_1d(k, self.freq)
        # BUG FIX: F.scaled_dot_product_attention raises a RuntimeError when an
        # explicit attn_mask is combined with is_causal=True, and the original
        # hard-coded is_causal=True regardless of self.causal. Fold the causal
        # constraint into the explicit mask instead, honoring self.causal.
        attn_mask = padding_mask
        if self.causal:
            causal = torch.tril(torch.ones(N, N, dtype=torch.bool, device=x.device))
            if padding_mask.dtype == torch.bool:
                attn_mask = padding_mask & causal
            else:
                # Additive (float) mask convention: disallowed pairs get -inf.
                attn_mask = padding_mask.masked_fill(~causal, float("-inf"))
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = rearrange(x, "B h N D -> B N (h D)")
        x = self.proj(x)
        return x
class MLPGeGLU(nn.Module):
    """GeGLU feed-forward block: linearOut(GELU(linearIn(x)) * gate(x)).

    Args:
        dim: size of the axis the linears act on (the feature axis, or the
            token axis when transpose=True).
        upsample: hidden-layer expansion factor.
        transpose: if True, swap the token and feature axes around the linear
            stack so mixing happens per-feature instead of per-token.
    """
    def __init__(self, dim: int, upsample=2, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.dim = dim
        hidden = upsample * dim
        self.linearIn = nn.Linear(dim, hidden, bias=True)
        self.gate = nn.Linear(dim, hidden, bias=True)
        self.linearOut = nn.Linear(hidden, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Map (B, N, D) -> (B, N, D) through the gated MLP, then RMSNorm.

        NOTE(review): RMSNorm(x) is invoked like a function here; if
        utils.RMSNorm is an nn.Module class this builds a module rather than
        normalizing the tensor — confirm against utils.
        """
        if self.transpose:
            # Treat each feature's sequence of token values as the vector.
            x = rearrange(x, "B N D -> B D N")
        gated = F.gelu(self.linearIn(x)) * self.gate(x)
        x = self.linearOut(gated)
        if self.transpose:
            x = rearrange(x, "B D N -> B N D")
        return RMSNorm(x)
class MLPSwiGLU(nn.Module):
    """SwiGLU feed-forward block: linearOut(SiLU(linearIn(x)) * gate(x)).

    Args:
        dim: size of the axis the linears act on (the feature axis, or the
            token axis when transpose=True).
        upsample: hidden-layer expansion factor.
        transpose: if True, swap the token and feature axes around the linear
            stack so mixing happens per-feature instead of per-token.
    """
    def __init__(self, dim: int, upsample=2, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.dim = dim
        widened = upsample * dim
        self.linearIn = nn.Linear(dim, widened, bias=True)
        self.gate = nn.Linear(dim, widened, bias=True)
        self.linearOut = nn.Linear(widened, dim, bias=True)

    def forward(self, x: torch.Tensor):
        """Map (B, N, D) -> (B, N, D) through the gated MLP, then RMSNorm.

        NOTE(review): RMSNorm(x) is invoked like a function here; if
        utils.RMSNorm is an nn.Module class this builds a module rather than
        normalizing the tensor — confirm against utils.
        """
        if self.transpose:
            # Treat each feature's sequence of token values as the vector.
            x = rearrange(x, "B N D -> B D N")
        x = self.linearOut(F.silu(self.linearIn(x)) * self.gate(x))
        if self.transpose:
            x = rearrange(x, "B D N -> B N D")
        return RMSNorm(x)