| """ |
| model/model.py |
| |
| SLLM — Small Language Model (decoder-only Transformer). |
| |
| Full architecture: |
| tokens (B, T) |
| -> Embedding (vocab_size -> d_model) |
| -> N x TransformerBlock (attention + FFN) |
| -> Final RMSNorm |
| -> LM Head (Linear d_model -> vocab_size) <- weight-TIED to embedding |
| |
| Weight tying: |
| The embedding matrix and the LM head output matrix share the same weights. |
| - Halves memory for the embedding/output layers. |
| - A standard practice introduced by Press & Wolf (2016) and adopted by GPT-2. |
| |
| Weight initialization: |
| - Embeddings: std=0.02 (GPT-2 convention) |
| - Linear layers: std=0.02 |
| - Output projections (attn.o_proj, mlp.down): std = 0.02/sqrt(2*n_layers) |
| - Scaled down per GPT-2/NanoGPT: at initialization, the residual |
| stream grows as sqrt(n_layers), so we scale residual contributions down. |
| |
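| Forward: |
| Returns (logits, loss), where logits has shape (B, T, vocab_size). |
| loss is the cross-entropy over all positions when `targets` is passed and |
| None otherwise, so the training loop may still compute its own loss externally. |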
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| from torch.utils.checkpoint import checkpoint |
|
|
| from model.config import ModelConfig |
| from model.norm import RMSNorm |
| from model.block import TransformerBlock |
|
|
|
|
class SLLM(nn.Module):
    """
    Decoder-only Transformer language model.

    Data flow: token embedding -> ``config.n_layers`` TransformerBlocks ->
    RMSNorm -> LM head. The LM head shares its weight tensor with the token
    embedding (weight tying), so both layers are backed by one parameter.
    """

    def __init__(self, config: "ModelConfig"):
        """
        Build all sub-modules and initialize their weights.

        Args:
            config: model hyper-parameters. This class reads ``vocab_size``,
                ``d_model``, ``n_layers`` and ``context_length`` from it and
                passes the whole object on to every TransformerBlock.
        """
        super().__init__()
        self.config = config

        # Token IDs -> d_model vectors.
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Stack of n_layers identical blocks.
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )

        # Final normalization applied before the output projection.
        self.norm = RMSNorm(config.d_model)

        # Output projection d_model -> vocab_size ...
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # ... tied to the embedding: both modules now hold the same Parameter.
        self.lm_head.weight = self.token_emb.weight

        # Off by default; switched on via enable_gradient_checkpointing().
        self._gradient_checkpointing = False

        # Tag the residual output projections BEFORE the init pass. Without
        # the tags _init_weights cannot tell them apart from ordinary Linear
        # layers and the scaled-down std would never be applied.
        self._tag_residual_projections()
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """
        Initialize one sub-module (used as the callback of ``self.apply``).

        - nn.Linear: weight ~ Normal(0, 0.02), bias (if any) zeroed.
        - nn.Linear tagged with ``_is_residual``: weight ~
          Normal(0, 0.02 / sqrt(2 * n_layers)).
        - nn.Embedding: weight ~ Normal(0, 0.02).
        - Anything else is left untouched.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if getattr(module, "_is_residual", False):
                # Residual-branch outputs get a smaller std so the residual
                # stream does not grow with depth at initialization.
                std = 0.02 / math.sqrt(2 * self.config.n_layers)
            # Single draw with the final std (no throw-away first init).
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _tag_residual_projections(self):
        """Set ``_is_residual = True`` on each block's attn.o_proj and mlp.down."""
        for block in self.blocks:
            block.attn.o_proj._is_residual = True
            block.mlp.down._is_residual = True

    def _mark_residual_projections(self):
        """
        Tag the residual output projections and re-run weight initialization.

        ``__init__`` already tags the layers before its init pass, so calling
        this is no longer required; it is kept for existing callers.

        WARNING: this re-draws the weights of every Linear and Embedding in
        the model, so it must not be called after loading trained weights.
        """
        self._tag_residual_projections()
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
    ):
        """
        Run the model and optionally compute the training loss.

        Args:
            input_ids : (B, T) integer token IDs, T <= config.context_length
            targets   : (B, T) optional target token IDs; may be a
                        non-contiguous view such as ``batch[:, 1:]``

        Returns:
            logits : (B, T, vocab_size)
            loss   : scalar cross-entropy over all B*T positions if targets
                     is given, else None

        Raises:
            AssertionError: if T exceeds config.context_length.
        """
        _, T = input_ids.shape
        assert T <= self.config.context_length, (
            f"Sequence length {T} exceeds context_length {self.config.context_length}"
        )

        x = self.token_emb(input_ids)

        # Checkpointing only pays off when a backward pass follows, so it is
        # skipped in eval mode even when enabled.
        use_checkpoint = self._gradient_checkpointing and self.training
        for block in self.blocks:
            if use_checkpoint:
                # Activations inside the block are recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view): targets sliced out of a larger batch are
            # non-contiguous and view(-1) would raise on them.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (greedy, temperature or top-k sampling).

        Side effect: puts the model in eval mode and leaves it there; call
        ``.train()`` afterwards when generating in the middle of training.

        Args:
            input_ids      : (B, T) prompt tokens
            max_new_tokens : number of tokens to append
            temperature    : softmax temperature (1.0 = neutral, <1 = sharper);
                             0 selects greedy decoding (argmax, no sampling)
            top_k          : if set, sample from the k most likely tokens only
                             (ignored for greedy decoding)

        Returns:
            (B, T + max_new_tokens) token IDs

        Raises:
            ValueError: if temperature is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        # Dividing logits by 0 would yield inf/NaN probabilities, so a zero
        # temperature takes the argmax path instead.
        greedy = temperature == 0

        self.eval()
        ctx_len = self.config.context_length
        for _ in range(max_new_tokens):
            # Keep only the most recent ctx_len tokens (no-op while shorter).
            ctx = input_ids[:, -ctx_len:]

            # Only the prediction at the last position is needed.
            last_logits = self(ctx)[0][:, -1, :]

            if greedy:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                # The division creates a new tensor, so the in-place masking
                # below never touches the model output.
                scaled = last_logits / temperature
                if top_k is not None:
                    top_vals, _idx = torch.topk(scaled, min(top_k, scaled.size(-1)))
                    # Everything below the k-th largest logit gets zero probability.
                    scaled[scaled < top_vals[:, [-1]]] = float('-inf')
                probs = torch.softmax(scaled, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def enable_gradient_checkpointing(self):
        """
        Enable gradient checkpointing to reduce activation memory.

        When enabled and the model is in training mode, ``forward`` wraps each
        block in ``torch.utils.checkpoint.checkpoint``, recomputing the
        block's activations during the backward pass instead of storing them.
        The original author's estimate is roughly 30% more compute for
        roughly 40% less memory.
        """
        self._gradient_checkpointing = True

    def count_params(self, non_embedding: bool = False) -> int:
        """
        Return the number of parameters.

        The embedding and the LM head share one tensor, which
        ``self.parameters()`` yields once, so the total counts it once.

        Args:
            non_embedding: if True, subtract the token-embedding matrix.
                Because of weight tying this also removes the LM head
                weights from the count.
        """
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            total -= self.token_emb.weight.numel()
        return total
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build each preset, report its parameter counts, then push
    # one random batch through the model to check output shape and loss.
    from model.config import SLLM_100M, SLLM_150M

    presets = (("SLLM-100M", SLLM_100M), ("SLLM-150M", SLLM_150M))
    batch_shape = (2, 64)  # (batch size, sequence length)

    for name, preset in presets:
        net = SLLM(preset)

        total = net.count_params()
        non_emb = net.count_params(non_embedding=True)
        print("\n".join([
            f"{name}",
            f" total params : {total/1e6:.1f}M",
            f" non-embedding params : {non_emb/1e6:.1f}M",
            f" embedding params : {(total-non_emb)/1e6:.1f}M",
        ]))

        # Random inputs first, then random targets (same RNG order as before).
        token_ids, labels = (
            torch.randint(0, preset.vocab_size, batch_shape) for _ in range(2)
        )

        logits, loss = net(token_ids, labels)
        print(f" logits shape : {logits.shape}")
        print(f" loss : {loss.item():.4f}")
        print()
|
|