# i_like_purple / model.py
# Uploaded by dasdasddds ("Upload 16 files", commit 93783dd)
"""
GPT-300M Model Architecture
============================
A decoder-only transformer built entirely from scratch in PyTorch.
Architecture features:
- Pre-norm transformer blocks (RMSNorm before attention and feed-forward)
- Rotary Position Embeddings (RoPE), or learned absolute positions when RoPE is disabled
- Multi-Head Self-Attention with causal masking
- GELU or SwiGLU activation in feed-forward layers
- Optional weight tying (token embeddings ↔ LM head)
- KV-Cache for efficient autoregressive generation
- Fused attention via F.scaled_dot_product_attention when available (PyTorch 2.0+)
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import GPT300MConfig
# ═══════════════════════════════════════════════════════════════════════
# ROTARY POSITION EMBEDDINGS (RoPE)
# ═══════════════════════════════════════════════════════════════════════
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (Su et al., 2021).

    Precomputes cos/sin tables for positions ``[0, max_seq_len)`` and, on
    each call, returns the rows for positions ``[offset, offset + seq_len)``.
    If a call asks for positions beyond the cached range (for example a
    KV-cached decode step past ``max_seq_len``), the tables are rebuilt to
    cover them instead of silently returning a truncated (possibly empty)
    slice, which would otherwise broadcast q/k down to zero length.

    Args:
        dim: Rotary dimension (the per-head size in this file). Presumably
            even -- with an odd value the tables are ``dim + 1`` wide.
        max_seq_len: Number of positions precomputed up front.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Precompute cos/sin tables
        cos, sin = self._compute_tables(inv_freq, max_seq_len)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    @staticmethod
    def _compute_tables(inv_freq: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, 2 * len(inv_freq)] in float32."""
        t = torch.arange(seq_len, dtype=torch.float32, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq.float())
        # Duplicate the frequencies so column j and j + width/2 share an angle,
        # matching the two halves that rotate_half swaps.
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()

    def forward(self, seq_len: int, offset: int = 0):
        """Return (cos, sin) rows for positions offset .. offset + seq_len - 1."""
        end = offset + seq_len
        if end > self.cos_cached.size(0):
            cos, sin = self._compute_tables(self.inv_freq, end)
            # Keep the dtype the buffers currently have (e.g. after model.to()).
            self.cos_cached = cos.to(self.cos_cached.dtype)
            self.sin_cached = sin.to(self.sin_cached.dtype)
        return (
            self.cos_cached[offset:end],
            self.sin_cached[offset:end],
        )
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves (a, b) of the last dimension of *x* to (-b, a).

    This is the pairwise 90-degree rotation used by rotary embeddings.
    The split follows ``torch.chunk``, so for an odd size the first half
    is the larger one.
    """
    first, second = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_emb(
    q: torch.Tensor, k: torch.Tensor,
    cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with the given RoPE cos/sin tables.

    Args:
        q, k: In this file these are [B, n_heads, T, head_dim].
        cos, sin: [T, head_dim] tables. Two leading singleton axes are added
            so they broadcast over the batch and head dimensions.

    Returns:
        ``(q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin)``.
    """
    # [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)
# ═══════════════════════════════════════════════════════════════════════
# RMSNORM (faster alternative to LayerNorm)
# ═══════════════════════════════════════════════════════════════════════
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias).

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps)`` in float32, casts the
    result back to ``x``'s dtype, then multiplies by a learned per-feature
    gain (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x32 * inv_rms).type_as(x)
        return normed * self.weight
# ═══════════════════════════════════════════════════════════════════════
# MULTI-HEAD SELF-ATTENTION
# ═══════════════════════════════════════════════════════════════════════
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with causal masking and optional KV-cache.

    Q, K and V come from a single fused projection. When a KV-cache is
    supplied, the new keys/values are appended to the cached ones, so the key
    length T_k can exceed the query length T_q. Causal masking is then
    aligned to the *end* of the key sequence: query row i sits at absolute
    position ``T_k - T_q + i`` and may attend to keys ``0 .. T_k - T_q + i``.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout
        # Q, K, V projections (fused for efficiency)
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        # Output projection
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Use PyTorch's fused scaled_dot_product_attention when present (2.0+)
        self.flash_attn = hasattr(F, "scaled_dot_product_attention")

    @staticmethod
    def _causal_mask(T_q: int, T_k: int, device: torch.device) -> torch.Tensor:
        """Return a boolean [T_q, T_k] mask, True where attention is allowed.

        Bottom-right aligned (query i may see keys 0 .. T_k - T_q + i). With
        T_q == T_k this is the ordinary lower triangle.
        """
        return torch.ones(T_q, T_k, device=device, dtype=torch.bool).tril(diagonal=T_k - T_q)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over the current tokens plus any cached keys/values.

        Args:
            x: [B, T, d_model] hidden states for the current tokens.
            cos, sin: RoPE tables for the T current positions; when either is
                None, no rotary embedding is applied.
            mask: Optional mask broadcastable to [B, n_heads, T, T_k]. The
                manual path masks entries equal to zero/False. SDPA treats a
                boolean mask the same way but *adds* a float mask to the
                scores, so pass a boolean mask (True = attend) for identical
                behavior on both paths. When None, causal masking is used.
            kv_cache: (k, v) from earlier steps, each [B, n_heads, T_past, head_dim].
            use_cache: If True, return the concatenated (k, v) for reuse.

        Returns:
            ([B, T, d_model] output, new (k, v) cache or None).
        """
        B, T, C = x.shape
        # Project to Q, K, V
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=-1)
        # Reshape: [B, T, n_heads, head_dim] → [B, n_heads, T, head_dim]
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE
        if cos is not None and sin is not None:
            q, k = apply_rotary_emb(q, k, cos, sin)
        # KV-Cache for generation: append new keys/values along the time axis
        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)
        new_cache = (k, v) if use_cache else None
        # SDPA's is_causal=True is top-left aligned, which is only correct when
        # T_q == T_k; any call that carries a KV-cache uses the manual path.
        if self.flash_attn and not use_cache and kv_cache is None:
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=mask is None,
            )
        else:
            # Manual attention for compatibility / KV-cache
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            else:
                causal = self._causal_mask(q.size(2), k.size(2), x.device)
                scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)
        # Reshape back to [B, T, d_model] and project
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        out = self.resid_dropout(self.out_proj(attn_out))
        return out, new_cache
# ═══════════════════════════════════════════════════════════════════════
# FEED-FORWARD NETWORK
# ═══════════════════════════════════════════════════════════════════════
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    The variant is selected by ``config.activation``:
      * ``"gelu"``:   ``down(GELU(up(x)))``
      * ``"swiglu"``: ``down(SiLU(gate(x)) * up(x))``
    Dropout is applied to the output. Any other activation name raises
    ``ValueError``.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        # Creation order (up, down, dropout, then gate) is kept fixed so the
        # parameter registration / initialisation order does not change.
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.use_swiglu = config.activation == "swiglu"
        if self.use_swiglu:
            self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
            self.act = nn.SiLU()
        elif config.activation == "gelu":
            self.act = nn.GELU()
        else:
            raise ValueError(f"Unknown activation: {config.activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_swiglu:
            hidden = self.act(self.up_proj(x))
        else:
            hidden = self.act(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))
# ═══════════════════════════════════════════════════════════════════════
# TRANSFORMER BLOCK
# ═══════════════════════════════════════════════════════════════════════
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block.

    ``h = x + Attn(RMSNorm(x))`` followed by ``out = h + FFN(RMSNorm(h))``.
    Returns ``(out, cache)``, where ``cache`` is whatever the attention layer
    returns for its KV-cache (None unless ``use_cache`` is True).
    """

    def __init__(self, config: GPT300MConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Construction order (ln1, attn, ln2, ffn) fixes parameter order.
        self.ln1 = RMSNorm(config.d_model, eps=config.norm_eps)
        self.attn = MultiHeadAttention(config)
        self.ln2 = RMSNorm(config.d_model, eps=config.norm_eps)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        attn_out, new_cache = self.attn(
            self.ln1(x),
            cos=cos,
            sin=sin,
            mask=mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        hidden = x + attn_out
        out = hidden + self.ffn(self.ln2(hidden))
        return out, new_cache
# ═══════════════════════════════════════════════════════════════════════
# GPT-300M: THE FULL MODEL
# ═══════════════════════════════════════════════════════════════════════
class GPT300M(nn.Module):
    """
    GPT-300M: A 300-million parameter autoregressive language model.
    Architecture:
        Token Embedding → [Transformer Block × n_layers] → RMSNorm → LM Head
    Each Transformer Block:
        RMSNorm → Multi-Head Attention (+ RoPE) → Residual
        → RMSNorm → Feed-Forward (GELU or SwiGLU) → Residual
    Positions come from RoPE when ``config.rope`` is set; otherwise a learned
    absolute position embedding is added to the token embeddings.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        self.config = config
        # ── Embeddings ───────────────────────────────────────────────
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        # Rotary embeddings
        if config.rope:
            self.rotary = RotaryEmbedding(
                config.head_dim, config.max_seq_len, config.rope_theta
            )
        else:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        # ── Transformer Blocks ───────────────────────────────────────
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.n_layers)
        ])
        # ── Output ───────────────────────────────────────────────────
        self.ln_f = RMSNorm(config.d_model, eps=config.norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight
        # Initialize weights
        self.apply(self._init_weights)
        # Scale residual-branch output projections by 1/sqrt(2 * n_layers)
        # (two residual additions per block).
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))

    def _init_weights(self, module: nn.Module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        kv_caches: Optional[list] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]:
        """
        Forward pass.
        Args:
            input_ids: [B, T] token indices
            targets: [B, T] target token indices for loss computation; may be
                non-contiguous (e.g. a slice like ``batch[:, 1:]``). Positions
                equal to ``config.pad_token_id`` are ignored in the loss.
            kv_caches: List of KV-cache tuples, one per layer
            use_cache: Whether to return updated KV-caches
            position_offset: Absolute position of input_ids[:, 0] (used with
                KV-cache generation)
        Returns:
            logits: [B, T, vocab_size]
            loss: scalar loss if targets provided, else None
            new_caches: Updated KV-caches if use_cache=True, else None
        """
        _, T = input_ids.shape
        assert T <= self.config.max_seq_len, (
            f"Sequence length {T} exceeds max {self.config.max_seq_len}"
        )
        # Token embeddings
        x = self.token_emb(input_ids)  # [B, T, d_model]
        # Position information
        if self.config.rope:
            cos, sin = self.rotary(T, offset=position_offset)
        else:
            positions = torch.arange(position_offset, position_offset + T, device=input_ids.device)
            x = x + self.pos_emb(positions)
            cos, sin = None, None
        x = self.drop(x)
        # Transformer blocks
        new_caches = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            cache_i = kv_caches[i] if kv_caches is not None else None
            x, new_cache = layer(x, cos, sin, kv_cache=cache_i, use_cache=use_cache)
            if use_cache:
                new_caches.append(new_cache)
        # Final norm and LM head
        x = self.ln_f(x)
        logits = self.lm_head(x)  # [B, T, vocab_size]
        # Loss
        loss = None
        if targets is not None:
            # reshape (not view): targets are often a non-contiguous slice.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),
                ignore_index=self.config.pad_token_id,
            )
        return logits, loss, new_caches

    @staticmethod
    def _apply_repetition_penalty(
        next_logits: torch.Tensor, token_ids: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        """Penalize, per row, every token id that already appears in *token_ids*.

        Positive logits are divided by *penalty* and non-positive ones
        multiplied by it, so ``penalty > 1`` always lowers a seen token's
        score. Works for any batch size; returns a new tensor.
        """
        index = token_ids.long()
        seen = torch.gather(next_logits, 1, index)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        # Duplicate ids write identical values, so scatter order is irrelevant.
        return next_logits.scatter(1, index, seen)

    @staticmethod
    def _sample_next_token(
        next_logits: torch.Tensor, temperature: float, top_k: int, top_p: float
    ) -> torch.Tensor:
        """Pick one token per row from [B, vocab] logits; returns [B, 1].

        ``temperature <= 0`` means greedy argmax. Otherwise the logits are
        scaled by 1/temperature, filtered by top-k (if ``top_k > 0``) and
        nucleus/top-p (if ``top_p < 1``), then sampled.
        """
        if temperature <= 0:
            return next_logits.argmax(dim=-1, keepdim=True)
        next_logits = next_logits / temperature
        # Top-k filtering
        if top_k > 0:
            topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
            next_logits = next_logits.masked_fill(next_logits < topk_vals[:, -1:], float("-inf"))
        # Top-p (nucleus) filtering
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token once the mass of strictly higher-ranked tokens
            # reaches top_p. Always keep the top token so a row can never be
            # fully masked (which would make the probabilities NaN).
            sorted_mask = cumprobs - sorted_probs >= top_p
            sorted_mask[:, 0] = False
            sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))
            # sorted_idx is a permutation, so every vocab slot is rewritten.
            next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
        probs = F.softmax(next_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV-cache.
        Args:
            input_ids: [B, T] prompt token IDs
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (<= 0 means greedy decoding)
            top_k: Top-k sampling (0 disables)
            top_p: Nucleus sampling threshold (>= 1 disables)
            repetition_penalty: Penalty for tokens already in the sequence
                (applied to every row of the batch; 1.0 disables)
            eos_token_id: Stop once every row has produced this token. Rows
                that finish early are padded with eos_token_id.
        Returns:
            [B, T + n] generated token IDs, with n <= max_new_tokens. n is
            smaller when all rows hit EOS, or when the sequence is one token
            longer than config.max_seq_len (no position is left to feed the
            newest token back through the model).
        The model is put in eval mode for generation; its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            B = input_ids.size(0)
            # Initial forward pass to populate KV-cache
            logits, _, kv_caches = self.forward(input_ids, use_cache=True)
            generated = input_ids
            finished = None
            if eos_token_id is not None:
                finished = torch.zeros(B, dtype=torch.bool, device=input_ids.device)
            for step in range(max_new_tokens):
                # Get logits for the last token
                next_logits = logits[:, -1, :]  # [B, vocab_size]
                if repetition_penalty != 1.0:
                    next_logits = self._apply_repetition_penalty(
                        next_logits, generated, repetition_penalty
                    )
                next_token = self._sample_next_token(next_logits, temperature, top_k, top_p)
                if finished is not None:
                    # Rows that already emitted EOS keep emitting EOS.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                    finished |= next_token.squeeze(-1) == eos_token_id
                generated = torch.cat([generated, next_token], dim=1)
                # Stop on EOS (all rows done)
                if finished is not None and bool(finished.all()):
                    break
                # No more tokens requested, or feeding the new token would need
                # position index >= max_seq_len.
                if step == max_new_tokens - 1 or generated.size(1) > self.config.max_seq_len:
                    break
                # Forward pass with KV-cache (only the new token)
                logits, _, kv_caches = self.forward(
                    next_token,
                    kv_caches=kv_caches,
                    use_cache=True,
                    position_offset=generated.size(1) - 1,
                )
            return generated
        finally:
            self.train(was_training)

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (tied weights are counted once)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def model_summary(self) -> str:
        """Return a human-readable, multi-line model summary string."""
        total = self.count_parameters(trainable_only=False)
        trainable = self.count_parameters(trainable_only=True)
        lines = [
            "=" * 60,
            " GPT-300M Model Summary",
            "=" * 60,
            f" Total parameters: {total:>15,}",
            f" Trainable parameters: {trainable:>15,}",
            f" d_model: {self.config.d_model:>15}",
            f" n_heads: {self.config.n_heads:>15}",
            f" n_layers: {self.config.n_layers:>15}",
            f" d_ff: {self.config.d_ff:>15}",
            f" vocab_size: {self.config.vocab_size:>15}",
            f" max_seq_len: {self.config.max_seq_len:>15}",
            f" RoPE: {'Yes' if self.config.rope else 'No':>15}",
            f" Weight tying: {'Yes' if self.config.tie_weights else 'No':>15}",
            f" Flash Attention: {'Yes' if self.layers[0].attn.flash_attn else 'No':>15}",
            "=" * 60,
        ]
        return "\n".join(lines)
# ═══════════════════════════════════════════════════════════════════════
# QUICK TEST
# ═══════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    from config import gpt_tiny

    # Smoke test on the tiny configuration so it runs quickly on CPU.
    tiny_cfg = gpt_tiny()
    net = GPT300M(tiny_cfg)
    print(net.model_summary())

    # Forward pass with targets: expect [2, 32, vocab] logits and a scalar loss.
    batch_ids = torch.randint(0, tiny_cfg.vocab_size, (2, 32))
    batch_targets = torch.randint(0, tiny_cfg.vocab_size, (2, 32))
    out_logits, out_loss, _ = net(batch_ids, targets=batch_targets)
    print(f"\nForward pass OK: logits={out_logits.shape}, loss={out_loss.item():.4f}")

    # Sampled generation from a short random prompt.
    prompt_ids = torch.randint(0, tiny_cfg.vocab_size, (1, 8))
    sample = net.generate(prompt_ids, max_new_tokens=16, temperature=0.8)
    print(f"Generation OK: {sample.shape}")