# AlpineLLM-App/model/transformer_decoder.py
""" Architecture of the TransformerDecoder """
import torch
import torch.nn as nn
from torch.nn import functional as F
class TransformerDecoder(nn.Module):
""" GPT-style decoder-only language model """
def __init__(self, vocab_size, hyperparam_cfg, device):
super(TransformerDecoder, self).__init__()
self.device = device
# model hyperparameters
embedding_dim = hyperparam_cfg.embedding_dim
num_layers = hyperparam_cfg.num_layers
self.context_len = hyperparam_cfg.context_len
        # token embedding table maps each token id to a learned dense vector
self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
# pos embedding table adds information about the position of each token in the context
self.pos_embedding_table = nn.Embedding(self.context_len, embedding_dim)
# stack multiple transformer blocks to increase model capacity
self.tfblocks = nn.Sequential(*[TFBlock(hyperparam_cfg) for _ in range(num_layers)])
# final normalization and linear layer to produce logits for each token in the vocabulary
self.ln_f = nn.LayerNorm(embedding_dim)
self.lm_head = nn.Linear(embedding_dim, vocab_size)
        # better weight initialization for faster, more stable training (GPT-2 style: normal, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx):
"""
        The forward pass of the model returns logits of shape (B, T, C),
        where B=batch_size, T=context_len, C=vocab_size.
"""
# idx is a (B,T) tensor of integers which are indices in the current context
B, T = idx.shape
token_embd = self.token_embedding_table(idx) # (batch_size, context_len, embedding_dim)
        positions = torch.arange(T, device=self.device)  # tensor([0, 1, 2, ..., T-1])
pos_embd = self.pos_embedding_table(positions) # (context_len, embedding_dim)
x = token_embd + pos_embd # (batch_size, context_len, embedding_dim)
x = self.tfblocks(x) # (batch_size, context_len, embedding_dim)
x = self.ln_f(x) # (batch_size, context_len, embedding_dim)
logits = self.lm_head(x) # (batch_size, context_len, vocab_size)
return logits
def generate(self, idx, max_new_tokens):
""" Generate new tokens from the model """
for _ in range(max_new_tokens):
# crop idx to the last context_len tokens
idx_context = idx[:, -self.context_len:]
# get the predictions
logits = self(idx_context) # (B,T,C)
# focus only on the last time step
logits = logits[:, -1, :] # (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution to get the next token index
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
class TFBlock(nn.Module):
""" Single transformer block: communication (attention) followed by computation (dense) """
def __init__(self, hyperparam_cfg):
super(TFBlock, self).__init__()
# model hyperparameters
embedding_dim = hyperparam_cfg.embedding_dim
num_heads = hyperparam_cfg.num_heads
context_len = hyperparam_cfg.context_len
dropout = hyperparam_cfg.dropout
# size of MultiHeadAttention matches the embedding dimension (num_heads * head_size = embedding_dim)
self.sa_heads = MultiHeadAttention(num_heads=num_heads,
head_size=embedding_dim // num_heads,
embedding_dim=embedding_dim,
context_len=context_len,
dropout=dropout)
self.feed_forward = FeedForward(embedding_dim, dropout)
self.ln1 = nn.LayerNorm(embedding_dim)
self.ln2 = nn.LayerNorm(embedding_dim)
def forward(self, x):
        # pre-norm residual connections: LayerNorm runs before each sub-layer and the sub-layer output is added back to its input
x = x + self.sa_heads(self.ln1(x))
x = x + self.feed_forward(self.ln2(x))
return x
class MultiHeadAttention(nn.Module):
""" Multiple heads of self-attention in parallel """
def __init__(self, num_heads, head_size, embedding_dim, context_len, dropout):
super(MultiHeadAttention, self).__init__()
self.heads = nn.ModuleList([AttentionHead(embedding_dim, head_size, context_len, dropout) for _ in range(num_heads)])
        # output projection mixes the heads and maps the concatenation back to embedding_dim for the residual connection
self.projection = nn.Linear(num_heads * head_size, embedding_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = torch.cat([h(x) for h in self.heads], dim=-1) # (batch, context_len, num_heads * head_size)
out = self.dropout(self.projection(x)) # (batch, context_len, embedding_dim)
return out
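
# Note: running each head as a separate module is easy to read but performs
# num_heads small matmuls per projection; a common optimization (not used in
# this file) is to compute q, k, v for all heads with one fused linear layer
# and reshape into (B, num_heads, T, head_size).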
class AttentionHead(nn.Module):
""" One head of self-attention """
def __init__(self, embedding_dim, head_size, context_len, dropout):
super(AttentionHead, self).__init__()
self.queries = nn.Linear(embedding_dim, head_size, bias=False)
self.keys = nn.Linear(embedding_dim, head_size, bias=False)
self.values = nn.Linear(embedding_dim, head_size, bias=False)
self.dropout = nn.Dropout(dropout)
# lower triangular matrix is used to mask out future tokens in the attention mechanism
self.register_buffer("mask", torch.tril(torch.ones(context_len, context_len)))
def forward(self, x):
B, T, C = x.shape # (batch_size, context_len, embedding_dim)
q = self.queries(x) # (batch, context_len, head_size)
k = self.keys(x) # (batch, context_len, head_size)
v = self.values(x) # (batch, context_len, head_size)
        # compute attention scores (query-key dot products)
        weights = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        # scale by 1/sqrt(head_size) so dot products don't grow with head size (stabilizes softmax gradients)
        weights = weights * k.shape[-1]**-0.5
        # mask out future tokens: positions where the mask is 0 (above the diagonal) are set to -inf
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
# softmax along the last dimension to get probabilities per row
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
        output = weights @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
return output
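
# The head above implements standard scaled dot-product attention with a
# causal mask: output = softmax(Q @ K^T / sqrt(head_size)) @ V. In PyTorch
# >= 2.0 an equivalent (and usually faster) fused call exists,
# F.scaled_dot_product_attention(q, k, v, is_causal=True); the explicit
# version is kept here so every step of the math stays visible.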
class FeedForward(nn.Module):
""" Single feed-forward layer followed by a non-linearity """
def __init__(self, embedding_dim, dropout):
super(FeedForward, self).__init__()
        # hidden size is 4x embedding_dim, matching the original Transformer paper
self.net = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim * 4),
nn.ReLU(),
nn.Linear(embedding_dim * 4, embedding_dim),
nn.Dropout(dropout)
)
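        # e.g. with embedding_dim=64 this maps 64 -> 256 -> 64 per position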
def forward(self, x):
return self.net(x)
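
# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a config object exposing the same fields the classes above read;
# the small values below are arbitrary, chosen only for a quick smoke test.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        embedding_dim=64,   # must be divisible by num_heads
        num_layers=2,
        num_heads=4,
        context_len=32,
        dropout=0.1,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TransformerDecoder(vocab_size=100, hyperparam_cfg=cfg, device=device).to(device)
    model.eval()  # disable dropout for generation

    # start from a single token (index 0) in a batch of one
    start = torch.zeros((1, 1), dtype=torch.long, device=device)
    with torch.no_grad():
        out = model.generate(start, max_new_tokens=20)
    print(out.shape)  # torch.Size([1, 21]): the prompt plus 20 sampled tokens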