# llm/model.py
# Uploaded by abersbail ("Upload 16 files", commit 7fc99b0, verified).
import math
import torch
from torch import nn
from torch.nn import functional as F
from .config import ModelConfig
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    Each position attends only to itself and earlier positions. Inputs may be
    at most ``config.block_size`` positions long, the size of the precomputed
    mask buffer.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.n_embd % config.n_heads != 0:
            raise ValueError("n_embd must be divisible by n_heads.")
        self.n_heads = config.n_heads
        self.head_dim = config.n_embd // config.n_heads
        # One fused projection emits queries, keys and values side by side.
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # Dropout applied to the attention weights.
        self.dropout = nn.Dropout(config.dropout)
        # Dropout applied to the projected output before it re-enters the
        # residual stream, mirroring FeedForward's trailing dropout. Dropout has
        # no parameters or buffers, so state_dict keys are unchanged.
        self.resid_dropout = nn.Dropout(config.dropout)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        # Shape (1, 1, block_size, block_size) broadcasts over (batch, heads).
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, n_embd)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the configured ``block_size``.
        """
        batch_size, seq_len, channels = x.shape
        if seq_len > self.mask.size(-1):
            # Without this check the sliced mask is too small and masked_fill
            # fails with an opaque broadcasting error.
            raise ValueError("Input sequence exceeds block size.")
        q, k, v = self.qkv(x).split(channels, dim=2)
        # (batch, seq, channels) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Hide future positions. The diagonal is always kept, so no row is all -inf.
        att = att.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, channels)
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, channels)
        return self.resid_dropout(self.proj(y))
class FeedForward(nn.Module):
    """Position-wise MLP: widen ``n_embd`` fourfold, apply GELU, project back, then dropout."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        ]
        # The Linear layers sit at indices 0 and 2; state_dict keys
        # ("net.0.*", "net.2.*") depend on this exact order.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP over the last dimension of ``x`` (expected size ``n_embd``)."""
        out = self.net(x)
        return out
class Block(nn.Module):
    """Pre-LayerNorm transformer layer.

    Causal self-attention is followed by a feed-forward MLP. Each sub-layer
    sees a LayerNorm-ed copy of its input, and its output is added back to
    the un-normalized residual stream.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Registration order and attribute names (ln1/attn/ln2/ff) fix both the
        # parameter-init RNG order and the state_dict keys, so keep them as-is.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ff = FeedForward(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attn_out = self.attn(self.ln1(x))
        hidden = x + attn_out
        mlp_out = self.ff(self.ln2(hidden))
        return hidden + mlp_out
class TinyTransformerLM(nn.Module):
    """Decoder-only language model built from a stack of pre-LN ``Block`` layers.

    Token embeddings are summed with learned absolute position embeddings,
    passed through ``config.n_layers`` blocks and a final LayerNorm, and
    projected to vocabulary logits by ``self.head``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding = nn.Embedding(config.block_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= config.block_size``.
            targets: Optional integer token ids holding ``batch * seq_len``
                elements (typically shaped like ``idx``). Need not be contiguous.

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(batch, seq_len, vocab_size)``.
            ``loss`` is the mean cross-entropy, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.block_size``.
        """
        _, seq_len = idx.shape
        if seq_len > self.config.block_size:
            raise ValueError("Input sequence exceeds block size.")
        positions = torch.arange(0, seq_len, device=idx.device)
        # The (seq_len, n_embd) position table broadcasts across the batch.
        x = self.token_embedding(idx) + self.position_embedding(positions)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # reshape rather than view: targets are often non-contiguous slices
            # (e.g. ``data[:, 1:]``), on which ``.view(-1)`` raises.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: Prompt token ids of shape ``(batch, seq_len)``. Only the last
                ``config.block_size`` tokens are fed to the model at each step.
            max_new_tokens: Number of tokens to sample.
            temperature: Logits are divided by this value, clamped to at least
                ``1e-5``. Zero or negative values therefore give near-greedy sampling.
            top_k: If given, only the ``top_k`` highest-scoring tokens stay
                eligible at each step (capped at the vocabulary size).

        Returns:
            Tensor of shape ``(batch, seq_len + max_new_tokens)``.

        Raises:
            ValueError: If ``top_k`` is given but is less than 1.

        Note:
            Gradients are disabled, but this method does not switch the module
            to eval mode. Dropout stays active if the model is in training mode.
        """
        if top_k is not None and top_k < 1:
            # torch.topk with k <= 0 would otherwise fail deep inside the loop.
            raise ValueError("top_k must be a positive integer.")
        for _ in range(max_new_tokens):
            # Crop the context; forward() rejects sequences longer than block_size.
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            # Only the final position's logits predict the next token.
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # values[:, [-1]] keeps a (batch, 1) column so the threshold
                # broadcasts per row.
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx