Spaces:
Sleeping
Sleeping
| # model/transformer_explained.py | |
| """ | |
| Tiny Transformer language model (educational). | |
| Components: | |
| - PositionalEncoding: sinusoidal positional encodings (buffered) | |
| - MultiHeadSelfAttention: returns attn weights optionally | |
| - FeedForward: MLP with GELU | |
| - TransformerBlock: attention + add&norm + FFN + add&norm | |
| - TinyTransformerLM: token embeddings, pos enc, stacked blocks, LM head | |
| """ | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    The table is precomputed once and stored as a buffer (not a learned
    parameter), then added to the token embeddings in ``forward``.

    Args:
        d_model: embedding dimension. Odd values are supported: the cosine
            channels simply get one fewer column than the sine channels.
        max_len: maximum sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model the 1::2 slice has only d_model // 2 columns, one
        # fewer than div_term; truncate to match. (The old code raised a
        # shape-mismatch error for any odd d_model.)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)  # follows .to()/.cuda(), never trained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x``.

        Args:
            x: (batch, seq_len, d_model) token embeddings; seq_len must not
               exceed ``max_len``.

        Returns:
            ``x + pe[:, :seq_len, :]`` with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # .to(x.device) is normally redundant (buffers move with the module)
        # but kept as a cheap safeguard against mixed-device inputs.
        return x + self.pe[:, :seq_len, :].to(x.device)
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention, optionally returning the attention weights.

    Shapes:
        x:    (batch, seq_len, d_model)
        out:  (batch, seq_len, d_model)
        attn: (batch, num_heads, seq_len, seq_len), when requested
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Single projection for q, k, v together; split after the matmul.
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            x: (batch, seq_len, d_model) input sequence.
            mask: positions where ``mask == 0`` are excluded from attention.
                Accepts either a (batch, seq_len) padding mask, which is
                expanded to (batch, 1, 1, seq_len), or any tensor that is
                already broadcastable to (batch, num_heads, seq_len, seq_len).
            return_attn: if True, also return the (post-dropout) attention
                weights for visualization.

        Returns:
            (out, attn) where attn is None unless ``return_attn`` is True.
        """
        B, S, D = x.shape
        qkv = self.qkv_proj(x)  # (B, S, 3*D)
        qkv = qkv.view(B, S, 3, self.num_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)  # each: (B, S, num_heads, d_k)
        # Transpose to (B, num_heads, S, d_k).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Scaled dot-product attention scores: (B, num_heads, S, S).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            if mask.dim() == 2:
                # The docstring promised a (B, S) padding mask, but raw
                # broadcasting would mis-align batch with the query axis;
                # expand explicitly to (B, 1, 1, S).
                mask = mask[:, None, None, :]
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
        attn = self.softmax(attn_scores)  # rows sum to 1 over key positions
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)  # (B, num_heads, S, d_k)
        # Merge heads back into (B, S, D).
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        out = self.out_proj(out)  # final linear
        if return_attn:
            return out, attn
        return out, None
class FeedForward(nn.Module):
    """Position-wise two-layer MLP applied independently at each position.

    Expands d_model -> d_ff, applies GELU, then projects back to d_model;
    dropout follows both the activation and the output projection. The
    stages live in ``self.net`` (an ``nn.Sequential``), so checkpoint keys
    are ``net.0`` … ``net.4``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        stages = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; input and output are (..., d_model)."""
        return self.net(x)
class TransformerBlock(nn.Module):
    """One pre-norm Transformer block.

    Sub-layer order: LayerNorm -> self-attention -> residual add, then
    LayerNorm -> feed-forward -> residual add. Note the normalization runs
    *before* each sub-layer (pre-norm), not post-norm "Add & Norm".
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the block; returns ``(x, attn_weights-or-None)``."""
        # Attention sub-layer with residual connection.
        attn_out, attn_weights = self.attn(self.ln1(x), mask=mask, return_attn=return_attn)
        x = x + attn_out
        # Feed-forward sub-layer with residual connection.
        x = x + self.ff(self.ln2(x))
        # attn_weights is already None when return_attn is False, so a
        # single return covers both cases.
        return x, attn_weights
class TinyTransformerLM(nn.Module):
    """
    Small Transformer language model for teaching and experiments.

    Tokenizer-agnostic: consumes raw token ids, embeds them, adds sinusoidal
    positional encodings, runs a stack of pre-norm blocks, applies a final
    LayerNorm, and projects to per-position vocabulary logits.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_layers: int = 4,
        num_heads: int = 4,
        d_ff: int = 1024,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)
        blocks = [
            TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(d_model)
        # Projection to vocabulary logits (weights not tied to tok_emb).
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self, input_ids: torch.LongTensor, mask: Optional[torch.Tensor] = None, return_attn_layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: (B, S) token ids.
            mask: optional attention mask forwarded to every block.
            return_attn_layer: index of the layer whose attention weights
                should be captured, or None to skip.

        Returns:
            (logits, attn_weights) where logits is (B, S, vocab_size) and
            attn_weights is None unless ``return_attn_layer`` matched a layer.
        """
        hidden = self.pos_enc(self.tok_emb(input_ids))  # (B, S, d_model)
        attn_weights = None
        for layer_idx, block in enumerate(self.layers):
            want_attn = return_attn_layer is not None and layer_idx == return_attn_layer
            if want_attn:
                hidden, attn_weights = block(hidden, mask=mask, return_attn=True)
            else:
                hidden, _ = block(hidden, mask=mask, return_attn=False)
        logits = self.head(self.ln_f(hidden))  # (B, S, vocab_size)
        return logits, attn_weights