| """ |
| transformer.py |
| ============== |
| Transformer Decoder-only architecture implemented from scratch in PyTorch. |
| |
| This module is part of the project: |
| "A bilingual PT+EN LLM with BPE tokenizer and training loop |
| implemented from scratch, with didactic and documented code" |
| |
| Author : André Costa |
| License : MIT |
| |
| ═════════════════════════════════════════════════════════════ |
| THEORETICAL BACKGROUND |
| ═════════════════════════════════════════════════════════════ |
| |
| The Transformer architecture (Vaswani et al., 2017) |
| ------------------------------------------------- |
| The Transformer originally emerged as an encoder-decoder model for |
| machine translation. For generative language models, we use only |
| the decoder half — called "decoder-only" or "causal LM". |
| |
| This is the architecture used by GPT-2, GPT-3, GPT-4, LLaMA, Mistral, |
| and virtually all modern LLMs. |
| |
| Why decoder-only for text generation? |
| -------------------------------------------- |
| The decoder-only uses causal attention (also called masked attention): |
| each token can only "see" previous tokens, never future ones. |
| This allows training the model to predict the next token — the standard |
| pre-training objective (Language Modeling or LM loss). |
| |
| Input : [t1, t2, t3, t4] |
| Output : [t2, t3, t4, t5] ← each position predicts the next token |
| |
| Overview of the implemented architecture |
| ----------------------------------------- |
| Our implementation incorporates modern improvements over the original |
| 2017 Transformer: |
| |
| 1. RMSNorm (Zhang & Sennrich, 2019) instead of LayerNorm |
| → More efficient: no mean computation, normalizes by the root mean square only |
| |
| 2. RoPE — Rotary Position Embedding (Su et al., 2021) instead of |
| absolute positional embeddings |
| → Better generalization to sequences longer than those seen in training |
| |
| 3. SwiGLU (Shazeer, 2020) instead of FFN with ReLU |
| → Gated activation learns to "filter" information adaptively |
| |
| 4. Pre-norm (norm before attention/FFN) instead of post-norm |
| → More stable training, healthier gradients |
| |
| These are exactly the choices made by LLaMA (Touvron et al., 2023), |
| which have become the industry standard. |
| |
| Data flow through the model: |
| tokens (B, T) |
| ↓ nn.Embedding |
| x (B, T, d_model) |
| ↓ N × TransformerBlock |
| x (B, T, d_model) |
| ↓ final RMSNorm |
| x (B, T, d_model) |
| ↓ Linear (lm_head) |
| logits (B, T, vocab_size) |
| |
| where B = batch size, T = seq_len, d_model = model dimension |
| |
| References: |
| - Vaswani, A. et al. (2017). Attention is all you need. NeurIPS. |
| - Zhang, B., & Sennrich, R. (2019). Root mean square layer normalization. |
| - Su, J. et al. (2021). RoFormer: Enhanced transformer with rotary |
| position embedding. arXiv:2104.09864. |
| - Shazeer, N. (2020). GLU variants improve transformer. arXiv:2002.05202. |
| - Touvron, H. et al. (2023). LLaMA: Open and efficient foundation |
| language models. arXiv:2302.13971. |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ModelConfig:
    """
    Architecture hyperparameters for the decoder-only Transformer.

    Keeping every hyperparameter in one dataclass makes it possible to:
        - save and load the architecture alongside the weights
        - reproduce experiments exactly
        - change the model size without touching the model code

    Nomenclature follows the literature:
        d_model  = embedding / hidden dimension
        n_heads  = number of attention heads
        n_layers = number of stacked Transformer blocks
        d_ff     = inner dimension of the feed-forward network, typically
                   4 x d_model (original FFN) or about 8/3 x d_model (SwiGLU)

    Reference sizes, as computed by ``n_params`` with vocab_size=16384:
        Tiny  (~8.4M):   d_model=256, n_heads=4,  n_layers=4,  d_ff=1024
        Small (~42.0M):  d_model=512, n_heads=8,  n_layers=8,  d_ff=2048
        Base  (~125.8M): d_model=768, n_heads=12, n_layers=12, d_ff=3072

    ``d_head`` (= d_model // n_heads) is derived in ``__post_init__``; it is
    a plain attribute, not a dataclass field, so it is not a constructor
    argument.
    """

    # Vocabulary size and maximum context length (in tokens).
    vocab_size: int = 16384
    seq_len: int = 512

    # Core architecture dimensions.
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    d_ff: int = 1536

    # Dropout probability shared by embedding, attention and FFN dropout.
    dropout: float = 0.1

    # When True, attention uses F.scaled_dot_product_attention if available.
    use_flash: bool = True

    def __post_init__(self):
        """Validate hyperparameter consistency and derive ``d_head``."""
        assert self.d_model % self.n_heads == 0, (
            f"d_model ({self.d_model}) must be divisible by "
            f"n_heads ({self.n_heads})"
        )

        # Per-head width; every head gets an equal slice of d_model.
        self.d_head = self.d_model // self.n_heads

    @property
    def n_params(self) -> int:
        """
        Number of trainable parameters of the model this config describes.

        Counted components (all layers in this file are bias-free):
            - Embedding: vocab_size x d_model. The LM head is tied to the
              embedding matrix in MiniLM, so it adds no parameters and is
              counted only once here.
            - Per block: attention (4 x d_model^2), SwiGLU FFN
              (3 x d_model x d_ff) and two RMSNorm gains (2 x d_model).
            - Final RMSNorm gain: d_model.

        Buffers (causal mask, RoPE table) are not parameters and are
        excluded.

        Returns:
            Parameter count as an integer.
        """
        embedding = self.vocab_size * self.d_model
        attention = 4 * self.d_model ** 2
        feed_forward = 3 * self.d_model * self.d_ff
        block_norms = 2 * self.d_model
        per_block = attention + feed_forward + block_norms
        final_norm = self.d_model
        return embedding + self.n_layers * per_block + final_norm
|
|
|
|
| |
| |
| |
|
|
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

    LayerNorm re-centres and re-scales:
        LayerNorm(x) = (x - mean) / (std + eps) * gamma + beta

    RMSNorm skips the re-centring and only divides by the root mean square
    of the last dimension:
        RMSNorm(x) = x / RMS(x) * gamma
        RMS(x)     = sqrt(mean(x^2) + eps)

    Consequences:
        - no mean subtraction
        - no beta (bias) parameter; the only parameter is the gain ``weight``

    Args:
        d_model: Size of the last dimension to normalize.
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        # Learnable per-feature gain (gamma), initialised to 1.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` over its last dimension and apply the gain.

        Args:
            x: Tensor of shape (..., d_model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The statistics are computed in float32 regardless of input dtype.
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)

        # Cast back to the input dtype before applying the gain.
        normalized = (x32 / rms).to(x.dtype)
        return normalized * self.weight
|
|
|
|
| |
| |
| |
|
|
def precompute_rope_freqs(d_head: int, seq_len: int, base: float = 10000.0) -> torch.Tensor:
    """
    Pre-compute the complex rotation factors used by RoPE.

    RoPE (Su et al., 2021) encodes position by rotating query and key
    vectors. Each dimension pair i has a base frequency

        theta_i = 1 / base^(2i / d_head)

    and the rotation applied at position m is by the angle m * theta_i,
    stored here as the unit complex number exp(i * m * theta_i).

    Geometric intuition:
        - each pair of dimensions (2i, 2i+1) forms a 2D plane
        - at position m that plane is rotated by m * theta_i
        - the dot product of a rotated q (position m) and a rotated k
          (position n) depends only on the offset m - n

    Args:
        d_head: Dimension of each attention head. Must be even, because
            dimensions are rotated in pairs.
        seq_len: Number of positions to pre-compute.
        base: Frequency base (10000 is the original RoPE default).

    Returns:
        Complex tensor of shape (seq_len, d_head // 2).

    Raises:
        ValueError: If ``d_head`` is odd. Without this check an odd value
            would yield (d_head + 1) // 2 columns, which cannot be paired
            with any (..., d_head) tensor.
    """
    if d_head % 2 != 0:
        raise ValueError(
            f"d_head ({d_head}) must be even: RoPE rotates pairs of dimensions"
        )

    # theta_i for every dimension pair, shape (d_head // 2,).
    pair_offsets = torch.arange(0, d_head, 2).float()
    theta = 1.0 / (base ** (pair_offsets / d_head))

    # Position indices m = 0 .. seq_len - 1.
    positions = torch.arange(seq_len).float()

    # angles[m, i] = m * theta_i, shape (seq_len, d_head // 2).
    angles = torch.outer(positions, theta)

    # Unit-magnitude complex numbers exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def apply_rope(x: torch.Tensor, freqs_complex: torch.Tensor) -> torch.Tensor:
    """
    Rotate a query or key tensor with Rotary Position Embedding.

    Steps:
        1. View consecutive pairs of the last dimension as complex numbers
        2. Multiply each by the pre-computed rotation factor of its position
        3. Unpack back into a real tensor of the original shape

    Args:
        x: Tensor of shape (B, T, n_heads, d_head).
        freqs_complex: Complex rotation factors of shape (T, d_head // 2).

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    batch, seq, heads, head_dim = x.shape

    # (B, T, H, D) -> (B, T, H, D/2, 2) -> complex (B, T, H, D/2).
    # The computation runs in float32.
    paired = x.float().reshape(batch, seq, heads, head_dim // 2, 2)
    as_complex = torch.view_as_complex(paired)

    # (T, D/2) -> (1, T, 1, D/2) so it broadcasts over batch and heads.
    rotation = freqs_complex[None, :, None, :]

    # Complex multiplication performs the 2D rotation of every pair.
    rotated = torch.view_as_real(as_complex * rotation)

    # (B, T, H, D/2, 2) -> (B, T, H, D), back in the caller's dtype.
    return rotated.reshape(batch, seq, heads, head_dim).to(x.dtype)
|
|
|
|
| |
| |
| |
|
|
class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention with a causal mask and RoPE.

    Scaled dot-product attention (Vaswani et al., 2017):
        Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V

    "Causal": position t may only attend to positions <= t, which is what
    makes next-token training valid.

    "Multi-head": the computation runs in ``n_heads`` independent subspaces
    of width ``d_head`` whose outputs are concatenated and mixed by W_O:
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O

    Implementation outline:
        1. One fused linear layer produces Q, K and V
        2. RoPE is applied to Q and K (not to V)
        3. Causal attention, via the fused PyTorch kernel when enabled and
           available, otherwise via an explicit masked softmax
        4. Output projection back to d_model

    Args:
        config: Model configuration.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.d_model = config.d_model

        # Single bias-free projection producing Q, K and V side by side.
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=False)

        # Output projection W_O.
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        # Dropout on the attention weights.
        self.attn_dropout = nn.Dropout(config.dropout)

        # Lower-triangular mask of ones, shape (1, 1, seq_len, seq_len),
        # used by the explicit (non-fused) attention path.
        full = config.seq_len
        lower_triangle = torch.tril(torch.ones(full, full))
        self.register_buffer("causal_mask", lower_triangle.view(1, 1, full, full))

    def forward(
        self,
        x: torch.Tensor,
        freqs_complex: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute causal multi-head attention.

        Args:
            x: Input tensor, shape (B, T, d_model).
            freqs_complex: RoPE rotation factors, shape (T, d_head // 2).

        Returns:
            Output tensor, shape (B, T, d_model).
        """
        batch, seq, width = x.shape

        # Fused projection, then split the last dimension into Q, K, V.
        q, k, v = self.qkv_proj(x).split(self.d_model, dim=-1)

        # (B, T, d_model) -> (B, T, n_heads, d_head)
        per_head = (batch, seq, self.n_heads, self.d_head)
        q = q.view(per_head)
        k = k.view(per_head)
        v = v.view(per_head)

        # Positional rotation for queries and keys only.
        q = apply_rope(q, freqs_complex)
        k = apply_rope(k, freqs_complex)

        # (B, T, H, D) -> (B, H, T, D) so attention runs per head.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        fused_available = hasattr(F, "scaled_dot_product_attention")
        if not (self.config.use_flash and fused_available):
            # Explicit path: scaled scores, causal mask, softmax, dropout.
            scale = 1.0 / math.sqrt(self.d_head)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # Future positions get -inf so softmax assigns them zero weight.
            visible = self.causal_mask[:, :, :seq, :seq]
            scores = scores.masked_fill(visible == 0, float("-inf"))

            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = torch.matmul(weights, v)
        else:
            # Fused kernel; dropout is only active while training.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True,
            )

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, d_model): concatenate heads.
        merged = y.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.out_proj(merged)
|
|
|
|
| |
| |
| |
|
|
class SwiGLUFFN(nn.Module):
    """
    Feed-forward network with SwiGLU activation (Shazeer, 2020).

    The original Transformer FFN is two linear layers around a ReLU:
        FFN(x) = max(0, x W_1 + b_1) W_2 + b_2

    SwiGLU replaces the ReLU by a learned multiplicative gate:
        SwiGLU(x) = (Swish(x W_gate) * (x W_up)) W_down

    where ``*`` is element-wise and Swish(x) = x * sigmoid(x), which is
    ``F.silu`` in PyTorch.

    Why three matrices?
        The gate adds a third projection. To stay near the parameter count
        of a classic FFN with d_ff = 4 x d_model, SwiGLU models typically
        use d_ff of about 8/3 x d_model.

    All three projections are bias-free. Dropout is applied to the gated
    hidden activations, before the down projection.

    Args:
        config: Model configuration (reads d_model, d_ff, dropout).
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        width, hidden = config.d_model, config.d_ff

        # Two parallel up-projections: one feeds the gate, one the values.
        self.gate_proj = nn.Linear(width, hidden, bias=False)
        self.up_proj = nn.Linear(width, hidden, bias=False)

        # Projection back to the model dimension.
        self.down_proj = nn.Linear(hidden, width, bias=False)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the SwiGLU feed-forward network.

        Args:
            x: Tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        # Swish-activated gate multiplied element-wise with the up branch.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
|
|
|
|
| |
| |
| |
|
|
class TransformerBlock(nn.Module):
    """
    One pre-norm Transformer block.

    The block chains two residual sub-layers:
        1. Causal self-attention (with RoPE)
        2. SwiGLU feed-forward network

    Pre-norm vs post-norm:
        The original Transformer (Vaswani et al., 2017) normalizes after the
        residual addition:
            x = Norm(x + SubLayer(x))
        This block normalizes the sub-layer input instead:
            x = x + SubLayer(Norm(x))
        which leaves the residual stream itself un-normalized.

    Residual connections (He et al., 2016):
        Adding the sub-layer output to its input gives gradients a direct
        path through the stack, independent of the sub-layer.

    Args:
        config: Model configuration.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        # Attention sub-layer and the norm applied to its input.
        self.norm1 = RMSNorm(config.d_model)
        self.attn = CausalSelfAttention(config)

        # Feed-forward sub-layer and the norm applied to its input.
        self.norm2 = RMSNorm(config.d_model)
        self.ffn = SwiGLUFFN(config)

    def forward(
        self,
        x: torch.Tensor,
        freqs_complex: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run ``x`` through attention and feed-forward, each with a residual.

        Args:
            x: Tensor of shape (B, T, d_model).
            freqs_complex: RoPE rotation factors of shape (T, d_head // 2).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        # Sub-layer 1: x + Attention(Norm(x))
        attended = self.attn(self.norm1(x), freqs_complex)
        x = x + attended

        # Sub-layer 2: x + FFN(Norm(x))
        transformed = self.ffn(self.norm2(x))
        return x + transformed
|
|
|
|
| |
| |
| |
|
|
class MiniLM(nn.Module):
    """
    Decoder-only Transformer language model (LLaMA-style).

    Forward-pass pipeline:
        1. Token embedding (followed by dropout)
        2. ``n_layers`` x TransformerBlock
        3. Final RMSNorm
        4. LM head projecting d_model -> vocab_size (logits)

    Weight tying:
        ``lm_head.weight`` is the very same Parameter object as
        ``token_emb.weight``, so the vocabulary matrix is stored, trained
        and counted once.

    RoPE:
        The rotation table for ``config.seq_len`` positions is computed once
        in ``__init__`` and registered as the buffer ``freqs_complex``;
        ``forward`` slices its first T rows.

    Args:
        config: Full model configuration.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.config = config

        # Token IDs -> dense vectors of size d_model.
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Dropout applied right after the embedding lookup.
        self.emb_dropout = nn.Dropout(config.dropout)

        # Stack of identical Transformer blocks.
        self.blocks = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.n_layers)
        ])

        # Normalization applied before the output projection.
        self.norm_final = RMSNorm(config.d_model)

        # Output projection to vocabulary logits (bias-free).
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share one Parameter between input and output.
        self.lm_head.weight = self.token_emb.weight

        # RoPE rotation factors for every position up to seq_len.
        freqs = precompute_rope_freqs(config.d_head, config.seq_len)
        self.register_buffer("freqs_complex", freqs)

        # Normal(0, 0.02) initialization for all Linear/Embedding weights.
        self.apply(self._init_weights)

        # The projections that write into the residual stream are then
        # re-initialized with std scaled by 1/sqrt(2 * n_layers)
        # (two residual additions per block).
        for name, param in self.named_parameters():
            if name.endswith(("out_proj.weight", "down_proj.weight")):
                nn.init.normal_(
                    param,
                    mean=0.0,
                    std=0.02 / math.sqrt(2 * config.n_layers),
                )

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize one sub-module (invoked for every sub-module by apply()).

        Follows the GPT-2 scheme:
            - Linear and Embedding weights: Normal(0, 0.02)
            - Linear bias (when present): zeros

        Args:
            module: Sub-module to initialize; other module types are left
                untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Model forward pass.

        Training mode (``targets`` provided):
            Logits are computed for every position and the cross-entropy
            loss against ``targets`` is returned with them.

        Inference mode (``targets`` is None):
            Only the last position is projected to logits, which is all
            that autoregressive generation needs.

        Args:
            input_ids: Token IDs, shape (B, T) with T <= config.seq_len.
            targets: Next-token IDs, shape (B, T). Positions equal to -1
                are ignored by the loss.

        Returns:
            Tuple (logits, loss):
                logits: unnormalized scores, shape (B, T, vocab_size) when
                    targets are given, otherwise (B, 1, vocab_size)
                loss: scalar tensor when targets are given, else None

        Raises:
            AssertionError: If T exceeds config.seq_len.

        Training example:
            input_ids = [t1, t2, t3, t4]
            targets   = [t2, t3, t4, t5]   (input shifted by one)
        """
        B, T = input_ids.shape
        assert T <= self.config.seq_len, (
            f"Sequence of length {T} exceeds seq_len={self.config.seq_len}"
        )

        # (B, T) -> (B, T, d_model)
        x = self.token_emb(input_ids)
        x = self.emb_dropout(x)

        # Rotation factors for the T positions actually present.
        freqs = self.freqs_complex[:T]

        for block in self.blocks:
            x = block(x, freqs)

        x = self.norm_final(x)

        if targets is None:
            # Project only the last position; x[:, -1:, :] keeps the time
            # axis so the result is (B, 1, vocab_size).
            return self.lm_head(x[:, -1:, :]), None

        logits = self.lm_head(x)

        # Flatten batch and time so every position is one classification.
        # reshape (not view) also accepts non-contiguous targets.
        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            targets.reshape(-1),
            ignore_index=-1,
        )
        return logits, loss

    @staticmethod
    def _top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        """Set every logit below the k-th largest of its row to -inf."""
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
        return logits.masked_fill(logits < kth_largest, float("-inf"))

    @staticmethod
    def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Keep, per row, the smallest top set whose probability reaches top_p."""
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift right by one so the token that crosses the threshold is
        # kept, and the most probable token is never removed.
        drop_sorted = cumulative > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False

        # Map the decision from sorted order back to vocabulary order.
        drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
        return logits.masked_fill(drop, float("-inf"))

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Autoregressive text generation.

        Each iteration:
            1. Feed the last ``seq_len`` tokens to the model
            2. Scale the next-token logits by the temperature
            3. Apply the optional top-k and top-p filters
            4. Sample one token per sequence
            5. Append it and repeat

        Temperature:
            - 0: greedy decoding (always the highest-scoring token; the
              top-k / top-p filters are irrelevant and skipped)
            - 1: the model's own distribution
            - > 1: flatter distribution, more random output

        Top-k sampling:
            Only the k highest-scoring tokens remain candidates.

        Top-p (nucleus) sampling (Holtzman et al., 2019):
            Only the smallest set of top tokens whose cumulative
            probability reaches p remains.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards.

        Args:
            input_ids: Context tokens, shape (B, T) with T >= 1. B = 1 is
                the common case, but any batch size works.
            max_new_tokens: Number of tokens to append.
            temperature: Randomness control, >= 0.
            top_k: Keep only the top-k tokens (e.g. 50); must be >= 1.
            top_p: Nucleus threshold (e.g. 0.9).

        Returns:
            Tensor of shape (B, T + max_new_tokens): context + new tokens.

        Raises:
            ValueError: If temperature is negative, top_k is smaller than 1,
                or the context is empty.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        if input_ids.size(1) == 0:
            raise ValueError("input_ids must contain at least one token")

        # Remember the caller's mode: generation must not silently leave a
        # model that is being trained in eval mode.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Sliding window: never feed more than seq_len tokens.
                context = input_ids[:, -self.config.seq_len:]

                logits, _ = self(context)
                logits = logits[:, -1, :]  # (B, vocab_size)

                if temperature == 0:
                    # Greedy decoding; avoids dividing by zero.
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    if temperature != 1.0:
                        logits = logits / temperature
                    if top_k is not None:
                        logits = self._top_k_filter(logits, top_k)
                    if top_p is not None:
                        logits = self._top_p_filter(logits, top_p)

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            self.train(was_training)

        return input_ids

    def count_parameters(self) -> dict[str, int]:
        """
        Count model parameters by component.

        The LM head shares its weight with the token embedding, so it is
        reported as 0 and the shared matrix is attributed to
        "token_embedding". "norm_layers" covers the two norms of every
        block plus the final norm.

        Returns:
            Dictionary with the parameter count per component and "total".
        """
        def count(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        block_norms = sum(count(b.norm1) + count(b.norm2) for b in self.blocks)

        return {
            "token_embedding": count(self.token_emb),
            "attention_layers": sum(count(b.attn) for b in self.blocks),
            "ffn_layers": sum(count(b.ffn) for b in self.blocks),
            "norm_layers": block_norms + count(self.norm_final),
            "lm_head": 0,
            "total": count(self),
        }

    def __repr__(self) -> str:
        """Return a compact summary of the configuration and parameter count."""
        params = self.count_parameters()
        return (
            f"MiniLM(\n"
            f" vocab_size={self.config.vocab_size}, "
            f"seq_len={self.config.seq_len}\n"
            f" d_model={self.config.d_model}, "
            f"n_heads={self.config.n_heads}, "
            f"n_layers={self.config.n_layers}\n"
            f" d_ff={self.config.d_ff}, "
            f"d_head={self.config.d_head}\n"
            f" params={params['total'] / 1e6:.1f}M\n"
            f")"
        )
|
|
|
|
| |
| |
| |
|
|
def estimate_vram(config: "ModelConfig", batch_size: int = 8, dtype_bytes: int = 2) -> dict:
    """
    Rough estimate of the memory needed to train the model.

    Four terms are added up:
        1. Parameters: n_params x dtype_bytes
        2. Gradients: same size as the parameters
        3. Optimizer state: 8 bytes per parameter (two fp32 moments, as
           kept by AdamW)
        4. Activations: one (batch_size, seq_len, d_model) tensor per block

    The activation term counts a single tensor per block only, so treat the
    result as an order-of-magnitude figure rather than an exact budget.

    Args:
        config: Model configuration (reads n_params, seq_len, d_model,
            n_layers).
        batch_size: Training batch size.
        dtype_bytes: Bytes per value (2 for bf16/fp16, 4 for fp32).

    Returns:
        Dictionary of formatted strings: one "<x> GB" entry per component,
        "total_estimate", and "n_params" in millions.
    """
    bytes_per_gb = 1e9
    param_count = config.n_params

    # Weights and gradients occupy the same amount of memory.
    weights_gb = param_count * dtype_bytes / bytes_per_gb
    gradients_gb = weights_gb

    # Two 4-byte moment estimates per parameter.
    optimizer_gb = param_count * 8 / bytes_per_gb

    # One residual-stream-sized tensor per Transformer block.
    bytes_per_block = batch_size * config.seq_len * config.d_model * dtype_bytes
    activations_gb = config.n_layers * bytes_per_block / bytes_per_gb

    total_gb = weights_gb + gradients_gb + optimizer_gb + activations_gb

    sizes_in_gb = {
        "parameters": weights_gb,
        "gradients": gradients_gb,
        "optimizer": optimizer_gb,
        "activations": activations_gb,
        "total_estimate": total_gb,
    }
    report = {label: f"{size:.2f} GB" for label, size in sizes_in_gb.items()}
    report["n_params"] = f"{param_count / 1e6:.1f}M"
    return report
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Demo / smoke test: report model sizes, then run one forward pass and a
    # short generation on random token IDs.
    banner = "=" * 60
    print(banner)
    print(" MiniLM Demo")
    print(banner)

    # Configuration of the model that is actually instantiated below.
    config = ModelConfig(
        vocab_size=16384,
        seq_len=512,
        d_model=512,
        n_heads=8,
        n_layers=8,
        d_ff=1536,
        dropout=0.1,
    )

    # Parameter estimates for three reference sizes (no model is built).
    print("\nAvailable configurations:")
    presets = {
        "Tiny (~15M)": ModelConfig(d_model=256, n_heads=4, n_layers=4, d_ff=768),
        "Small (~85M)": ModelConfig(d_model=512, n_heads=8, n_layers=8, d_ff=1536),
        "Base (~310M)": ModelConfig(d_model=768, n_heads=12, n_layers=12, d_ff=2304),
    }
    for label, preset in presets.items():
        print(f" {label}: {preset.n_params / 1e6:.0f}M params")

    print("\nInstantiating Small model...")
    model = MiniLM(config)
    print(model)

    # Per-component parameter counts; empty components are skipped.
    print("\nParameter distribution:")
    for component, n_weights in model.count_parameters().items():
        if n_weights > 0:
            print(f" {component:<20}: {n_weights / 1e6:.2f}M")

    print("\nVRAM estimate (batch=8, bf16):")
    vram_report = estimate_vram(config, batch_size=8, dtype_bytes=2)
    for entry, value in vram_report.items():
        print(f" {entry:<20}: {value}")

    # Forward pass on random tokens, on the GPU when one is available.
    print("\nForward pass test...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Device: {device}")
    model = model.to(device)

    batch, seq = 2, 64
    input_ids = torch.randint(0, config.vocab_size, (batch, seq)).to(device)
    targets = torch.randint(0, config.vocab_size, (batch, seq)).to(device)

    logits, loss = model(input_ids, targets)
    print(f" Input shape : {input_ids.shape}")
    print(f" Logits shape : {logits.shape}")
    print(f" Initial loss : {loss.item():.4f}")
    # A uniform prediction over the vocabulary has loss ln(vocab_size).
    print(f" Expected loss: {math.log(config.vocab_size):.4f} (maximum entropy)")

    # Short sampling run from a random 5-token prompt.
    print("\nGeneration test (10 tokens)...")
    prompt = torch.randint(0, config.vocab_size, (1, 5)).to(device)
    generated = model.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=50)
    print(f" Prompt shape : {prompt.shape}")
    print(f" Generated shape: {generated.shape}")
    print(f" New tokens : {generated[0, 5:].tolist()}")
    print("\nForward pass and generation OK.")
|
|