|
|
""" |
|
|
GPT-300M Model Architecture |
|
|
============================ |
|
|
A decoder-only transformer built entirely from scratch in PyTorch. |
|
|
|
|
|
Architecture features: |
|
|
- Pre-norm transformer blocks (RMSNorm)
|
|
- Rotary Position Embeddings (RoPE) |
|
|
- Multi-Head Self-Attention with causal masking |
|
|
- GELU (or optional SwiGLU) activation in feed-forward layers
|
|
- Optional weight tying (token embeddings ↔ LM head)
|
|
- KV-Cache for efficient autoregressive generation |
|
|
- Flash Attention support (PyTorch 2.0+) |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from config import GPT300MConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (Su et al., 2021).

    Precomputes ``cos``/``sin`` tables of shape ``[max_seq_len, dim]``. The
    per-pair frequency vector is duplicated across both halves of the last
    dimension, matching the half-split layout used by ``rotate_half``.

    If a caller requests positions beyond the precomputed length (for example
    during long KV-cache generation), the tables are regrown on demand instead
    of returning a silently truncated slice.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        # Kept so the tables can be rebuilt at full float32 precision later,
        # even if the inv_freq buffer has been cast to a lower precision.
        self.dim = dim
        self.theta = theta

        inv_freq = self._inv_freq(dim, theta)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        cos, sin = self._tables(inv_freq, max_seq_len)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    @staticmethod
    def _inv_freq(dim: int, theta: float, device=None) -> torch.Tensor:
        """Return the ``dim // 2`` rotation frequencies ``theta ** (-2i / dim)``."""
        return 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))

    @staticmethod
    def _tables(inv_freq: torch.Tensor, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return float32 ``cos``/``sin`` tables of shape ``[length, 2 * len(inv_freq)]``."""
        t = torch.arange(length, dtype=torch.float32, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()

    def forward(self, seq_len: int, offset: int = 0):
        """Return ``(cos, sin)`` rows for positions ``offset .. offset + seq_len - 1``.

        Args:
            seq_len: Number of consecutive positions requested.
            offset: First absolute position (non-zero when decoding with a KV-cache).

        Returns:
            Tuple of two tensors, each ``[seq_len, dim]``.
        """
        end = offset + seq_len
        if end > self.cos_cached.size(0):
            # Grow the cache. Recompute from dim/theta (not from the possibly
            # down-cast inv_freq buffer), then match the current buffer dtype.
            inv_freq = self._inv_freq(self.dim, self.theta, device=self.cos_cached.device)
            cos, sin = self._tables(inv_freq, end)
            self.cos_cached = cos.to(self.cos_cached.dtype)
            self.sin_cached = sin.to(self.sin_cached.dtype)
        return self.cos_cached[offset:end], self.sin_cached[offset:end]
|
|
|
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last dimension ``[a, b]`` (split into two halves) to ``[-b, a]``.

    This is the 90-degree rotation RoPE applies when paired features are laid
    out as (first half, second half) rather than interleaved.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
|
|
|
|
def apply_rotary_emb(
    q: torch.Tensor, k: torch.Tensor,
    cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by position-dependent angles (RoPE).

    ``cos`` and ``sin`` get two leading singleton dimensions so they broadcast
    over the leading (batch, head) axes of ``q``/``k``. Within this file they
    are ``[T, head_dim]`` slices from ``RotaryEmbedding``, and ``q``/``k`` are
    ``[B, n_heads, T, head_dim]``.

    Returns:
        ``(q_rotated, k_rotated)`` with the same shapes as the inputs.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Computes ``x / sqrt(mean(x**2, last dim) + eps)`` in float32, casts the
    result back to ``x``'s dtype, then multiplies by a learnable per-feature
    gain (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learnable gain; starts as the identity scaling.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).type_as(x) * self.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with causal masking and optional KV-cache.

    Q/K/V come from a single fused linear projection. Rotary embeddings are
    applied when ``cos``/``sin`` are given. Cached keys/values (if any) are
    prepended along the sequence axis. Attention then runs either through
    ``F.scaled_dot_product_attention`` (when available and not caching) or
    through an explicit softmax fallback.

    When no ``mask`` is passed, a causal mask is applied that treats the
    queries as the LAST ``T_q`` positions of the ``T_k``-long key sequence.
    This is required for correct KV-cache decoding, where ``T_q < T_k``.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout

        # Fused projection producing [q | k | v] along the last dimension.
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # scaled_dot_product_attention only exists in newer PyTorch versions.
        self.flash_attn = hasattr(F, "scaled_dot_product_attention")

    @staticmethod
    def _causal_mask(t_q: int, t_k: int, device) -> torch.Tensor:
        """Return a ``[t_q, t_k]`` bool mask (True = may attend).

        Query row ``i`` sits at absolute position ``t_k - t_q + i``, so it may
        see keys ``0 .. t_k - t_q + i``. When ``t_q == t_k`` this is the
        ordinary lower-triangular mask.
        """
        return torch.ones(t_q, t_k, device=device, dtype=torch.bool).tril(diagonal=t_k - t_q)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: ``[B, T, d_model]`` input; requires ``n_heads * head_dim == d_model``.
            cos, sin: Rotary tables for the T new positions (both or neither).
            mask: Optional explicit mask; replaces the default causal mask.
                In the fallback path, entries equal to 0/False are blocked.
                In the fused path it is forwarded as ``attn_mask``, so a bool
                mask (True = attend) behaves the same in both paths.
            kv_cache: Optional ``(k, v)`` from previous steps, each
                ``[B, n_heads, T_prev, head_dim]``.
            use_cache: If True, return the concatenated ``(k, v)``.

        Returns:
            ``(out [B, T, d_model], new_cache or None)``.
        """
        B, T, _ = x.shape

        q, k, v = self.qkv_proj(x).split(self.d_model, dim=-1)

        # [B, T, d_model] -> [B, n_heads, T, head_dim]
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        if cos is not None and sin is not None:
            q, k = apply_rotary_emb(q, k, cos, sin)

        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)

        new_cache = (k, v) if use_cache else None

        T_q, T_k = q.size(2), k.size(2)

        if self.flash_attn and not use_cache:
            attn_mask = mask
            is_causal = False
            if mask is None:
                if T_q == T_k:
                    is_causal = True
                else:
                    # SDPA's is_causal is top-left aligned, which is wrong when
                    # cached keys precede the queries; pass an explicit mask.
                    attn_mask = self._causal_mask(T_q, T_k, x.device)
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            else:
                # Bottom-right aligned: during decode (T_q=1) the new token
                # sees every cached key, not just key 0.
                causal = self._causal_mask(T_q, T_k, x.device)
                scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)

        # [B, n_heads, T, head_dim] -> [B, T, d_model]
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        out = self.resid_dropout(self.out_proj(attn_out))

        return out, new_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    With ``config.activation == "gelu"``: ``down(GELU(up(x)))``.
    With ``config.activation == "swiglu"``: ``down(SiLU(gate(x)) * up(x))``.
    Any other activation name raises ``ValueError``. Dropout is applied to the
    output.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        d_model, d_ff, use_bias = config.d_model, config.d_ff, config.bias

        # Creation order is kept stable: it fixes parameter order (and hence
        # the RNG sequence consumed during weight init).
        self.up_proj = nn.Linear(d_model, d_ff, bias=use_bias)
        self.down_proj = nn.Linear(d_ff, d_model, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

        kind = config.activation
        if kind == "swiglu":
            self.gate_proj = nn.Linear(d_model, d_ff, bias=use_bias)
            self.act = nn.SiLU()
        elif kind == "gelu":
            self.act = nn.GELU()
        else:
            raise ValueError(f"Unknown activation: {config.activation}")

        self.use_swiglu = kind == "swiglu"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_swiglu:
            hidden = self.act(self.up_proj(x))
        else:
            hidden = self.act(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: RMSNorm → Attention → Residual → RMSNorm → FFN → Residual.

    The attention sub-layer returns its (possibly updated) KV-cache, which the
    block passes straight back to the caller.
    """

    def __init__(self, config: GPT300MConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        eps = config.norm_eps
        self.ln1 = RMSNorm(config.d_model, eps=eps)
        self.attn = MultiHeadAttention(config)
        self.ln2 = RMSNorm(config.d_model, eps=eps)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        attn_out, new_cache = self.attn(self.ln1(x), cos, sin, mask, kv_cache, use_cache)
        hidden = x + attn_out
        hidden = hidden + self.ffn(self.ln2(hidden))
        return hidden, new_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPT300M(nn.Module):
    """
    GPT-300M: A 300-million parameter autoregressive language model.

    Architecture:
        Token Embedding → [Transformer Block × config.n_layers] → RMSNorm → LM Head

    Each Transformer Block:
        RMSNorm → Multi-Head Attention (+ RoPE) → Residual
        → RMSNorm → Feed-Forward (GELU or SwiGLU) → Residual

    Positions are encoded with RoPE when ``config.rope`` is true. Otherwise a
    learned absolute position embedding is added to the token embeddings.
    """

    def __init__(self, config: GPT300MConfig):
        super().__init__()
        self.config = config

        # Token embeddings and embedding dropout.
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)

        # Positional information: rotary (applied inside attention) or learned absolute.
        if config.rope:
            self.rotary = RotaryEmbedding(
                config.head_dim, config.max_seq_len, config.rope_theta
            )
        else:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)

        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.n_layers)
        ])

        self.ln_f = RMSNorm(config.d_model, eps=config.norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the LM head shares the token-embedding matrix.
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight

        self.apply(self._init_weights)

        # Residual-path output projections get a smaller std, shrinking with
        # depth by 1/sqrt(2 * n_layers) (two residual additions per block).
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))

    def _init_weights(self, module: nn.Module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        kv_caches: Optional[list] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]:
        """
        Forward pass.

        Args:
            input_ids: [B, T] token indices
            targets: [B, T] target token indices for loss computation
                (may be non-contiguous, e.g. ``tokens[:, 1:]``)
            kv_caches: List of KV-cache tuples, one per layer
            use_cache: Whether to return updated KV-caches
            position_offset: Offset for position embeddings (for KV-cache generation)

        Returns:
            logits: [B, T, vocab_size]
            loss: scalar loss if targets provided, else None
            new_caches: Updated KV-caches if use_cache=True
        """
        B, T = input_ids.shape
        assert T <= self.config.max_seq_len, (
            f"Sequence length {T} exceeds max {self.config.max_seq_len}"
        )

        x = self.token_emb(input_ids)

        if self.config.rope:
            cos, sin = self.rotary(T, offset=position_offset)
        else:
            positions = torch.arange(position_offset, position_offset + T, device=input_ids.device)
            x = x + self.pos_emb(positions)
            cos, sin = None, None

        x = self.drop(x)

        new_caches = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            cache_i = kv_caches[i] if kv_caches is not None else None
            x, new_cache = layer(x, cos, sin, kv_cache=cache_i, use_cache=use_cache)
            if use_cache:
                new_caches.append(new_cache)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view): shifted targets such as tokens[:, 1:] are
            # non-contiguous, and .view(-1) would raise on them.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),
                ignore_index=self.config.pad_token_id,
            )

        return logits, loss, new_caches

    @staticmethod
    def _sample_next_token(
        next_logits: torch.Tensor,
        generated: torch.Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> torch.Tensor:
        """Pick one token per row from last-position logits ``[B, vocab]``.

        Args:
            next_logits: ``[B, vocab]`` logits for the next position.
            generated: ``[B, L]`` token ids seen so far (prompt + output).
            temperature, top_k, top_p, repetition_penalty: see ``generate``.

        Returns:
            ``[B, 1]`` sampled (or greedy, if ``temperature <= 0``) token ids.
        """
        if repetition_penalty != 1.0:
            # Penalize each distinct id present in `generated` once: positive
            # logits are divided, non-positive ones multiplied. Duplicate ids
            # gather the same value, so the scatter stays consistent.
            seen = next_logits.gather(1, generated)
            seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
            next_logits = next_logits.scatter(1, generated, seen)

        if temperature <= 0:
            return next_logits.argmax(dim=-1, keepdim=True)

        next_logits = next_logits / temperature

        if top_k > 0:
            kth = torch.topk(next_logits, min(top_k, next_logits.size(-1))).values[:, -1:]
            next_logits = next_logits.masked_fill(next_logits < kth, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass of strictly more-probable tokens reaches top_p.
            drop = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) >= top_p
            # Always keep the top token, so top_p <= 0 cannot mask everything (NaN probs).
            drop[:, 0] = False
            sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
            # sorted_idx is a permutation, so every position is overwritten.
            next_logits = next_logits.scatter(1, sorted_idx, sorted_logits)

        probs = F.softmax(next_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV-cache.

        Puts the model in eval mode (it is not switched back).

        Args:
            input_ids: [B, T] prompt token IDs
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (<= 0 means greedy decoding)
            top_k: Top-k sampling (0 disables)
            top_p: Nucleus sampling threshold (>= 1.0 disables)
            repetition_penalty: Penalty for tokens already in each row (1.0 disables)
            eos_token_id: Rows that emit this token are padded with it afterwards;
                generation stops once every row has emitted it

        Returns:
            [B, T + n] generated token IDs, where n <= max_new_tokens
            (smaller only if every row hit eos_token_id)
        """
        self.eval()
        generated = input_ids
        if max_new_tokens <= 0:
            return generated

        # Prefill: process the whole prompt once and keep each layer's K/V.
        logits, _, kv_caches = self.forward(input_ids, use_cache=True)
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        for step in range(max_new_tokens):
            next_token = self._sample_next_token(
                logits[:, -1, :], generated, temperature, top_k, top_p, repetition_penalty
            )

            if eos_token_id is not None:
                # Finished rows keep emitting EOS so the batch stays rectangular.
                next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                finished = finished | next_token.squeeze(-1).eq(eos_token_id)

            generated = torch.cat([generated, next_token], dim=1)

            if eos_token_id is not None and bool(finished.all()):
                break
            if step == max_new_tokens - 1:
                # Nothing more will be sampled; skip the extra forward pass
                # (which could also step past the position table).
                break

            # Decode: only the new token, at absolute position len - 1.
            logits, _, kv_caches = self.forward(
                next_token,
                kv_caches=kv_caches,
                use_cache=True,
                position_offset=generated.size(1) - 1,
            )

        return generated

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (tied weights are counted once)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def model_summary(self) -> str:
        """Return a human-readable model summary string."""
        total = self.count_parameters(trainable_only=False)
        trainable = self.count_parameters(trainable_only=True)
        lines = [
            "=" * 60,
            " GPT-300M Model Summary",
            "=" * 60,
            f" Total parameters: {total:>15,}",
            f" Trainable parameters: {trainable:>15,}",
            f" d_model: {self.config.d_model:>15}",
            f" n_heads: {self.config.n_heads:>15}",
            f" n_layers: {self.config.n_layers:>15}",
            f" d_ff: {self.config.d_ff:>15}",
            f" vocab_size: {self.config.vocab_size:>15}",
            f" max_seq_len: {self.config.max_seq_len:>15}",
            f" RoPE: {'Yes' if self.config.rope else 'No':>15}",
            f" Weight tying: {'Yes' if self.config.tie_weights else 'No':>15}",
            f" Flash Attention: {'Yes' if self.layers[0].attn.flash_attn else 'No':>15}",
            "=" * 60,
        ]
        return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the tiny config, run one forward pass with a loss,
    # then a short sampled generation from a random prompt.
    from config import gpt_tiny

    tiny_cfg = gpt_tiny()
    demo_model = GPT300M(tiny_cfg)
    print(demo_model.model_summary())

    batch_shape = (2, 32)
    x = torch.randint(0, tiny_cfg.vocab_size, batch_shape)
    targets = torch.randint(0, tiny_cfg.vocab_size, batch_shape)
    logits, loss, _ = demo_model(x, targets=targets)
    print(f"\nForward pass OK: logits={logits.shape}, loss={loss.item():.4f}")

    prompt = torch.randint(0, tiny_cfg.vocab_size, (1, 8))
    gen = demo_model.generate(prompt, max_new_tokens=16, temperature=0.8)
    print(f"Generation OK: {gen.shape}")
|
|
|