| | """ |
| | Full transformer: TransformerBlock and top-level LLM model. |
| | Supports pure Transformer and hybrid Mamba-2 + Transformer architectures. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from .config import LMConfig |
| | from .layers import RMSNorm, RotaryEmbedding, SwiGLU |
| | from .attention import MultiHeadAttention |
| | from .mamba_block import Mamba2Block |
| |
|
| | |
| | |
| | |
| | try: |
| | import transformer_engine.pytorch as te |
| | HAS_TE = True |
| | except ImportError: |
| | te = None |
| | HAS_TE = False |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _load_hf_state_dict(path: Path) -> dict[str, torch.Tensor]:
    """Read a HuggingFace-style checkpoint directory into a CPU state dict.

    ``model.safetensors`` is preferred; ``pytorch_model.bin`` is used only
    when the safetensors file is absent.

    Args:
        path: Directory that holds the weight file.

    Returns:
        Mapping from parameter name to tensor, loaded on the CPU.

    Raises:
        FileNotFoundError: If neither weight file exists in ``path``.
    """
    st_file = path / "model.safetensors"
    if st_file.exists():
        # Imported lazily so safetensors is only required when actually used.
        from safetensors.torch import load_file

        return load_file(str(st_file), device="cpu")

    legacy_file = path / "pytorch_model.bin"
    if not legacy_file.exists():
        raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {path}")

    # weights_only=True restricts unpickling to tensors / plain containers.
    return torch.load(legacy_file, map_location="cpu", weights_only=True)
| |
|
| |
|
def _convert_hf_to_custom(hf_sd: dict[str, torch.Tensor], config: LMConfig) -> dict[str, torch.Tensor]:
    """Convert a HuggingFace LlamaForCausalLM state dict to our custom format.

    Key mapping:
        HF: model.embed_tokens.weight → embedding.weight
        HF: model.layers.{i}.self_attn.q/k/v_proj.weight → layers.{i}.attn.qkv_proj.weight (fused)
        HF: model.layers.{i}.self_attn.o_proj.weight → layers.{i}.attn.out_proj.weight
        HF: model.layers.{i}.input_layernorm.weight → layers.{i}.attn_norm.weight
        HF: model.layers.{i}.mlp.gate_proj.weight → layers.{i}.ffn.gate_proj.weight
        HF: model.layers.{i}.mlp.up_proj.weight → layers.{i}.ffn.up_proj.weight
        HF: model.layers.{i}.mlp.down_proj.weight → layers.{i}.ffn.down_proj.weight
        HF: model.layers.{i}.post_attention_layernorm.weight → layers.{i}.ffn_norm.weight
        HF: model.norm.weight → norm.weight
        HF: lm_head.weight → lm_head.weight (falls back to the embedding
            matrix when the checkpoint does not store a separate head)

    Args:
        hf_sd: State dict using HuggingFace Llama parameter names.
        config: Model config; only ``config.n_layers`` is read here.

    Returns:
        A new dict keyed by our parameter names. Apart from the fused QKV
        weight (a fresh concatenation), values are the same tensor objects
        as in ``hf_sd`` (no copies are made).

    Raises:
        KeyError: If a required HuggingFace key is missing from ``hf_sd``.
    """
    embed = hf_sd["model.embed_tokens.weight"]
    sd: dict[str, torch.Tensor] = {
        "embedding.weight": embed,
        "norm.weight": hf_sd["model.norm.weight"],
        # Checkpoints saved with tied word embeddings do not serialise a
        # separate lm_head.weight; reuse the embedding matrix in that case
        # instead of failing with KeyError.
        "lm_head.weight": hf_sd.get("lm_head.weight", embed),
    }

    # Per-layer tensors that are renamed one-to-one: (HF suffix, our suffix).
    renamed = (
        ("self_attn.o_proj.weight", "attn.out_proj.weight"),
        ("input_layernorm.weight", "attn_norm.weight"),
        ("mlp.gate_proj.weight", "ffn.gate_proj.weight"),
        ("mlp.up_proj.weight", "ffn.up_proj.weight"),
        ("mlp.down_proj.weight", "ffn.down_proj.weight"),
        ("post_attention_layernorm.weight", "ffn_norm.weight"),
    )

    for i in range(config.n_layers):
        pfx = f"model.layers.{i}"
        out = f"layers.{i}"

        # Fuse the separate Q, K and V projections along the output
        # dimension (dim 0), in that order.
        sd[f"{out}.attn.qkv_proj.weight"] = torch.cat(
            [hf_sd[f"{pfx}.self_attn.{name}_proj.weight"] for name in ("q", "k", "v")],
            dim=0,
        )

        for hf_suffix, our_suffix in renamed:
            sd[f"{out}.{our_suffix}"] = hf_sd[f"{pfx}.{hf_suffix}"]

    return sd
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TransformerBlock(nn.Module):
    """One pre-norm decoder block: attention followed by a gated FFN.

    Computation:
        h   = x + Attention( RMSNorm(x) )
        out = h + FFN( RMSNorm(h) )

    When ``config.use_fp8`` is set and Transformer Engine is importable, the
    FFN branch is a single ``te.LayerNormMLP`` configured with RMSNorm
    normalisation, so no separate ``ffn_norm`` module is created.
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        fp8_ffn = config.use_fp8 and HAS_TE

        self.attn_norm = RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self._use_fp8 = fp8_ffn

        if not fp8_ffn:
            # Standard path: explicit norm followed by the SwiGLU MLP.
            self.ffn_norm = RMSNorm(config.d_model)
            self.ffn = SwiGLU(config.d_model, config.d_ffn, bias=config.bias)
            return

        # FP8 path: the normalisation is part of te.LayerNormMLP, so
        # ffn_norm is kept only as a None placeholder.
        self.ffn_norm = None
        self.ffn = te.LayerNormMLP(
            hidden_size=config.d_model,
            ffn_hidden_size=config.d_ffn,
            bias=config.bias,
            activation="swiglu",
            normalization="RMSNorm",
        )

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run the block on a batch of hidden states.

        Args:
            x:   (B, T, C)
            cos: (T, head_dim // 2)
            sin: (T, head_dim // 2)

        Returns:
            (B, T, C)
        """
        # Attention sub-layer with residual connection.
        hidden = x + self.attn(self.attn_norm(x), cos, sin)

        # FFN sub-layer; the fused TE module receives the un-normalised input.
        ffn_input = hidden if self._use_fp8 else self.ffn_norm(hidden)
        return hidden + self.ffn(ffn_input)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LLM(nn.Module):
    """Decoder-only language model (pure attention or hybrid Mamba-2 + attention).

    Features:
        - Learned token embeddings with weight tying to the LM head
        - Rotary positional embeddings (no learned position embeddings)
        - Stack of pre-norm TransformerBlocks, optionally interleaved with
          Mamba2Blocks according to ``config.hybrid_pattern``
        - Final RMSNorm before the LM head
        - Optional cross-entropy loss computation (for training)
    """

    def __init__(self, config: LMConfig) -> None:
        """Build all sub-modules from ``config`` and initialise the weights.

        Raises:
            ValueError: If ``config.hybrid_pattern`` does not contain exactly
                ``config.n_layers`` entries, or contains an entry other than
                ``"M"`` or ``"A"``.
        """
        super().__init__()
        self.config = config

        # Token embedding table.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Layer stack: either driven by the whitespace-separated hybrid
        # pattern ("M" = Mamba-2, "A" = attention) or attention-only.
        if config.use_hybrid and config.hybrid_pattern:
            pattern = config.hybrid_pattern.strip().split()
            if len(pattern) != config.n_layers:
                raise ValueError(
                    f"hybrid_pattern has {len(pattern)} entries but "
                    f"n_layers={config.n_layers}"
                )
            layers: list[nn.Module] = []
            # Remembered so forward() knows which call signature each layer takes.
            self._layer_types: list[str] = pattern
            for layer_type in pattern:
                if layer_type == "M":
                    layers.append(Mamba2Block(
                        d_model=config.d_model,
                        d_state=config.mamba_d_state,
                        head_dim=config.mamba_head_dim,
                        expand=config.mamba_expand,
                        conv_kernel=config.mamba_conv_kernel,
                        n_groups=config.mamba_n_groups,
                        chunk_size=config.mamba_chunk_size,
                    ))
                elif layer_type == "A":
                    layers.append(TransformerBlock(config))
                else:
                    raise ValueError(
                        f"Unknown layer type '{layer_type}' in hybrid_pattern. "
                        f"Use 'M' (Mamba) or 'A' (Attention)."
                    )
            self.layers = nn.ModuleList(layers)
        else:
            self._layer_types = ["A"] * config.n_layers
            self.layers = nn.ModuleList(
                [TransformerBlock(config) for _ in range(config.n_layers)]
            )

        # Final normalisation and output projection.
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the LM head and the embedding share one parameter.
        self.lm_head.weight = self.embedding.weight

        # Rotary position embedding tables shared by all attention layers.
        self.rope = RotaryEmbedding(
            dim=config.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta,
        )

        # Initialise weights. nn.Module.apply() would visit every descendant,
        # including nn.Linear layers nested inside a Mamba2Block, and
        # overwrite whatever initialisation that block performed. Walk the
        # tree ourselves and prune the subtrees of self-initialising modules.
        self_init_types: list[type] = [Mamba2Block]
        if HAS_TE:
            self_init_types += [te.Linear, te.LayerNormMLP]
        self._init_tree(self, tuple(self_init_types), self._init_weights)

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _init_tree(root: nn.Module, skip_types: tuple, init_fn) -> None:
        """Apply ``init_fn`` to ``root`` and its descendants, pruning subtrees.

        Modules are visited children-first (the same order nn.Module.apply
        uses). A module that is an instance of ``skip_types`` is not passed
        to ``init_fn`` and none of its descendants are visited.

        Args:
            root: Module at which the walk starts.
            skip_types: Tuple of module classes whose whole subtree is skipped.
            init_fn: Callable invoked with each visited module.
        """
        def visit(module: nn.Module) -> None:
            if isinstance(module, skip_types):
                return
            for child in module.children():
                visit(child)
            init_fn(module)

        visit(root)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Initialise a single module (no recursion):
        - Linear / Embedding weights: N(0, 0.02)
        - Bias parameters: zeros
        - te.Linear / te.LayerNormMLP: skipped (TE manages its own init)
        - Mamba2Block: skipped (manages its own init)
        """
        if HAS_TE and isinstance(module, (te.Linear, te.LayerNormMLP)):
            return
        if isinstance(module, Mamba2Block):
            return
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: (B, T) long tensor of token indices
            targets: (B, T) long tensor of target token indices, or None.
                Use -1 (ignore_index) to mask positions. May be
                non-contiguous (e.g. a ``tokens[:, 1:]`` slice).

        Returns:
            logits: (B, T, vocab_size)
            loss: scalar cross-entropy loss, or None if targets is None
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Token embeddings.
        x = self.embedding(input_ids)

        # Rotary tables for the current sequence length.
        cos, sin = self.rope(T, device)

        # Mamba layers take only the hidden states; attention layers also
        # need the rotary tables.
        for layer, ltype in zip(self.layers, self._layer_types):
            if ltype == "M":
                x = layer(x)
            else:
                x = layer(x, cos, sin)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )

        return logits, loss

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_input_embeddings(self) -> nn.Embedding:
        """HuggingFace-compatible accessor for the token embedding layer."""
        return self.embedding

    # ------------------------------------------------------------------
    # Construction / loading
    # ------------------------------------------------------------------

    @classmethod
    def from_config(cls, config: LMConfig) -> "LLM":
        """Construct an LLM from an LMConfig instance."""
        return cls(config)

    @classmethod
    def from_pretrained(cls, path: str | Path) -> "LLM":
        """Load model from a checkpoint directory.

        Supports two formats (auto-detected, in this order):
            1. Custom: config.yaml + model.pt
            2. HuggingFace: config.json + model.safetensors (LlamaForCausalLM)

        Args:
            path: Checkpoint directory.

        Returns:
            A model with the checkpoint weights loaded (on CPU).

        Raises:
            FileNotFoundError: If the directory has neither config.yaml nor
                config.json.
        """
        path = Path(path)

        # Custom format takes precedence when both config files exist.
        if (path / "config.yaml").exists():
            config = LMConfig.from_yaml(path / "config.yaml")
            model = cls(config)
            state_dict = torch.load(
                path / "model.pt",
                map_location="cpu",
                weights_only=True,
            )
            model.load_state_dict(state_dict)
            return model

        # HuggingFace format.
        # NOTE(review): lm_head.weight and embedding.weight are one shared
        # tensor in this model, so an HF checkpoint with an untied head
        # cannot be represented exactly — confirm this is acceptable.
        if (path / "config.json").exists():
            config = LMConfig.from_hf_config(path / "config.json")
            model = cls(config)
            hf_sd = _load_hf_state_dict(path)
            our_sd = _convert_hf_to_custom(hf_sd, config)
            model.load_state_dict(our_sd)
            return model

        raise FileNotFoundError(
            f"No config.yaml or config.json found in {path}"
        )

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------

    def save_pretrained(self, path: str | Path) -> None:
        """Save config and model weights to a directory.

        The directory (and any missing parents) is created if needed.

        Creates:
            <path>/config.yaml
            <path>/model.pt
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self.config.to_yaml(path / "config.yaml")
        torch.save(self.state_dict(), path / "model.pt")
| |
|