| """Mamba-3 LM assembly that wires the official mamba-og Mamba3 mixer into a |
| prenorm Llama-style stack. The reason we don't reuse `mixer_seq_simple.py` |
| verbatim: its `create_block` only knows about Mamba1/Mamba2. We replicate the |
| same Block layout and weight init so the model is structurally identical to |
| what the Mamba-3 authors trained, just with Mamba3 as the mixer.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
| from collections import namedtuple |
| from functools import partial |
| from pathlib import Path |
| import sys |
|
|
| import torch |
| import torch.nn as nn |
|
|
| |
# Locate the repository root (one directory above the one holding this file)
# and put its ``mamba-og`` subdirectory at the front of ``sys.path`` unless it
# is already listed. Position 0 gives it precedence over every other entry
# when the ``mamba_ssm`` imports that follow are resolved.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_MAMBA_OG = _REPO_ROOT.joinpath("mamba-og")
if sys.path.count(str(_MAMBA_OG)) == 0:
    sys.path.insert(0, str(_MAMBA_OG))
|
|
| from mamba_ssm.modules.block import Block |
| from mamba_ssm.modules.mamba3 import Mamba3 |
| from mamba_ssm.modules.mlp import GatedMLP |
| from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn |
|
|
# Return type of ``Mamba3LM.forward``: the vocabulary logits plus the
# next-token loss (``None`` when no labels were supplied).
CausalLMOutput = namedtuple("CausalLMOutput", ("logits", "loss"))
|
|
|
|
def _init_weights(
    module: nn.Module,
    n_layer: int,
    initializer_range: float = 0.02,
    rescale_prenorm_residual: bool = True,
    n_residuals_per_layer: int = 1,
):
    """Initialise a single submodule; ``Mamba3LM`` runs it via ``self.apply``.

    * ``nn.Linear``: a bias, when present, is zeroed unless it carries a
      truthy ``_no_reinit`` attribute. The weight is not touched here.
    * ``nn.Embedding``: the weight is drawn from a normal distribution whose
      standard deviation is ``initializer_range``.

    Independently of the module type, when ``rescale_prenorm_residual`` is
    true every parameter of ``module`` whose name (relative to ``module``) is
    exactly ``out_proj.weight`` or ``fc2.weight`` is re-drawn with
    ``kaiming_uniform_(a=sqrt(5))`` and then divided in place by
    ``sqrt(n_residuals_per_layer * n_layer)``.

    Args:
        module: The submodule to initialise.
        n_layer: Number of layers, used in the rescaling divisor.
        initializer_range: Standard deviation for embedding weights.
        rescale_prenorm_residual: Whether to re-draw and shrink the matching
            output-projection weights.
        n_residuals_per_layer: Multiplier on ``n_layer`` inside the divisor.
    """
    if isinstance(module, nn.Linear):
        bias = module.bias
        if bias is not None and not getattr(bias, "_no_reinit", False):
            nn.init.zeros_(bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if not rescale_prenorm_residual:
        return

    rescaled_names = {"out_proj.weight", "fc2.weight"}
    for param_name, param in module.named_parameters():
        if param_name not in rescaled_names:
            continue
        nn.init.kaiming_uniform_(param, a=math.sqrt(5))
        # The in-place division must not be recorded by autograd.
        with torch.no_grad():
            param /= math.sqrt(n_residuals_per_layer * n_layer)
|
|
|
|
def _build_block(
    layer_idx: int,
    d_model: int,
    d_intermediate: int,
    ssm_cfg: dict,
    norm_epsilon: float,
    rms_norm: bool,
    residual_in_fp32: bool,
    fused_add_norm: bool,
    device,
    dtype,
):
    """Construct the ``Block`` for layer ``layer_idx`` with a Mamba3 mixer.

    Args:
        layer_idx: Index of the layer; forwarded to ``Mamba3`` and also stored
            on the returned block as ``layer_idx``.
        d_model: Model width, passed to ``Block`` as its first argument and
            used as the MLP's ``out_features``.
        d_intermediate: Hidden width of the ``GatedMLP``. ``0`` hands
            ``nn.Identity`` to ``Block`` in place of an MLP factory.
        ssm_cfg: Extra keyword arguments forwarded to ``Mamba3``.
        norm_epsilon: ``eps`` for the normalisation layers.
        rms_norm: Use ``RMSNorm`` when true, ``nn.LayerNorm`` otherwise.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        device: Device passed to every factory.
        dtype: Dtype passed to every factory.

    Returns:
        The constructed ``Block``.
    """
    tensor_opts = dict(device=device, dtype=dtype)
    norm_type = RMSNorm if rms_norm else nn.LayerNorm

    # ``Block`` receives factories rather than instances; how it calls them is
    # defined in mamba_ssm and not visible from this file.
    mixer_factory = partial(Mamba3, layer_idx=layer_idx, **ssm_cfg, **tensor_opts)
    norm_factory = partial(norm_type, eps=norm_epsilon, **tensor_opts)
    mlp_factory = (
        nn.Identity
        if d_intermediate == 0
        else partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **tensor_opts,
        )
    )

    built = Block(
        d_model,
        mixer_factory,
        mlp_factory,
        norm_cls=norm_factory,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    built.layer_idx = layer_idx
    return built
|
|
|
|
class Mamba3LM(nn.Module):
    """Causal language model built from a stack of Mamba-3 blocks.

    Token ids are embedded, run through ``n_layer`` blocks created by
    ``_build_block``, normalised by a final norm layer and projected to
    vocabulary logits. The output projection shares its weight with the
    embedding when ``tie_embeddings`` is true.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg: dict,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        norm_epsilon: float = 1e-5,
        pad_vocab_multiple: int = 8,
        tie_embeddings: bool = True,
        initializer_range: float = 0.02,
        device=None,
        dtype=None,
    ):
        """Create the embedding, the block stack, the final norm and the head.

        Args:
            d_model: Model width.
            n_layer: Number of blocks in the stack.
            d_intermediate: MLP hidden width per block; ``0`` means the blocks
                are built without an MLP factory.
            vocab_size: Vocabulary size before padding.
            ssm_cfg: Keyword arguments forwarded to every ``Mamba3`` mixer.
            rms_norm: Use ``RMSNorm`` when true, ``nn.LayerNorm`` otherwise.
            residual_in_fp32: Forwarded to the blocks and to the fused final
                norm.
            fused_add_norm: Forwarded to the blocks; also selects the fused
                ``layer_norm_fn`` path for the final norm in ``forward``.
            norm_epsilon: ``eps`` for every norm layer.
            pad_vocab_multiple: ``vocab_size`` is rounded up to the next
                multiple of this value.
            tie_embeddings: Share the ``lm_head`` weight with the embedding.
            initializer_range: Forwarded to ``_init_weights``.
            device: Device for all created parameters.
            dtype: Dtype for all created parameters.
        """
        super().__init__()

        # Round the vocabulary up to the next multiple of pad_vocab_multiple.
        overhang = vocab_size % pad_vocab_multiple
        if overhang != 0:
            vocab_size += pad_vocab_multiple - overhang

        self.vocab_size = vocab_size
        self.tie_embeddings = tie_embeddings
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        tensor_opts = dict(device=device, dtype=dtype)
        norm_type = RMSNorm if rms_norm else nn.LayerNorm

        # Submodules are created in this order: embedding, blocks, final norm,
        # output head.
        self.embedding = nn.Embedding(vocab_size, d_model, **tensor_opts)
        blocks = []
        for idx in range(n_layer):
            blocks.append(
                _build_block(
                    layer_idx=idx,
                    d_model=d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    device=device,
                    dtype=dtype,
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm_f = norm_type(d_model, eps=norm_epsilon, **tensor_opts)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **tensor_opts)

        # Two residual branches per layer when an MLP is configured, else one.
        residual_branches = 2 if d_intermediate != 0 else 1
        initializer = partial(
            _init_weights,
            n_layer=n_layer,
            initializer_range=initializer_range,
            n_residuals_per_layer=residual_branches,
        )
        self.apply(initializer)

        # Tying happens after initialisation, so the shared tensor is the
        # already-initialised embedding weight.
        if tie_embeddings:
            self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, labels=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: Token ids fed to the embedding (presumably
                ``(batch, seq)`` -- confirm against callers).
            labels: Optional targets. They are shifted by one position against
                the logits; entries equal to ``-100`` are ignored.

        Returns:
            ``CausalLMOutput(logits, loss)``; ``loss`` is ``None`` when
            ``labels`` is ``None``.
        """
        hidden = self.embedding(input_ids)
        carry = None
        for block in self.layers:
            hidden, carry = block(hidden, carry)

        if self.fused_add_norm:
            # One fused call adds the residual and applies the final norm.
            hidden = layer_norm_fn(
                hidden,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=carry,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        else:
            if carry is None:
                carry = hidden
            else:
                carry = hidden + carry
            hidden = self.norm_f(carry.to(self.norm_f.weight.dtype))

        logits = self.lm_head(hidden)
        if labels is None:
            return CausalLMOutput(logits=logits, loss=None)

        # Position t predicts token t + 1: drop the last logit and first label.
        predictions = logits[:, :-1, :].contiguous()
        targets = labels[:, 1:].contiguous()
        flat_predictions = predictions.view(-1, predictions.size(-1)).float()
        loss = nn.functional.cross_entropy(
            flat_predictions,
            targets.view(-1),
            ignore_index=-100,
        )
        return CausalLMOutput(logits=logits, loss=loss)

    def num_params(self, trainable_only: bool = True) -> int:
        """Count parameter elements, counting each parameter object once.

        Args:
            trainable_only: When true, parameters with ``requires_grad`` false
                are left out.

        Returns:
            Total number of elements over the selected parameters.
        """
        # Keyed by object identity so a shared (e.g. tied) parameter is only
        # counted a single time.
        sizes = {}
        for param in self.parameters():
            if trainable_only and not param.requires_grad:
                continue
            sizes.setdefault(id(param), param.numel())
        return sum(sizes.values())
|
|
|
|
def build_model_from_config(cfg: dict, device=None, dtype=None) -> Mamba3LM:
    """Build a ``Mamba3LM`` from a parsed YAML configuration.

    The ``model``, ``architecture`` and ``kernels`` sections of ``cfg`` are
    read. When ``architecture.bc_stabilizer`` is set to anything other than
    ``"bcnorm"`` (compared case-insensitively), the stock RMSNormGated B/C
    normalizers are replaced by the named element-wise stabilizer
    (DySoftSign, DyT, DyISRU, Derf, ...). The replacement is done *after* the
    model has been constructed, so the rest of the mixer (in_proj, biases,
    RoPE, SSD kernel) is bit-identical to the BCNorm baseline.

    Args:
        cfg: Parsed configuration mapping.
        device: Optional device; used at construction and for a final
            ``model.to(device=...)``.
        dtype: Optional dtype; used at construction and for a final
            ``model.to(dtype=...)``.

    Returns:
        The constructed model.

    Raises:
        KeyError: If a required configuration key is missing.
    """
    model_cfg, arch_cfg, kernel_cfg = cfg["model"], cfg["architecture"], cfg["kernels"]

    mixer_kwargs = {
        "d_state": model_cfg["d_state"],
        "expand": model_cfg["expand"],
        "headdim": model_cfg["head_dim"],
        "ngroups": model_cfg["ngroups"],
        "rope_fraction": arch_cfg["rope_fraction"],
        "is_outproj_norm": arch_cfg["is_outproj_norm"],
        "is_mimo": arch_cfg["is_mimo"],
        "mimo_rank": arch_cfg["mimo_rank"],
        "chunk_size": kernel_cfg["chunk_size"],
    }

    lm = Mamba3LM(
        d_model=model_cfg["d_model"],
        n_layer=model_cfg["n_layers"],
        d_intermediate=model_cfg["d_intermediate"],
        vocab_size=model_cfg["vocab_size"],
        ssm_cfg=mixer_kwargs,
        rms_norm=model_cfg["rms_norm"],
        residual_in_fp32=model_cfg["residual_in_fp32"],
        fused_add_norm=model_cfg["fused_add_norm"],
        norm_epsilon=model_cfg["norm_epsilon"],
        pad_vocab_multiple=model_cfg["pad_vocab_multiple"],
        tie_embeddings=model_cfg["tie_embeddings"],
        initializer_range=model_cfg["initializer_range"],
        device=device,
        dtype=dtype,
    )

    stabilizer_name = str(arch_cfg.get("bc_stabilizer", "bcnorm")).lower()
    if stabilizer_name != "bcnorm":
        # The adapter package lives under <repo>/src; make it importable and
        # import it only when a swap is actually requested.
        src_dir = str(_REPO_ROOT / "src")
        if src_dir not in sys.path:
            sys.path.insert(0, src_dir)
        from nfmamba.adapters.bc_stabilizer import install_bc_stabilizer

        swap_options = {
            "stabilize_b": bool(arch_cfg.get("stabilize_b", True)),
            "stabilize_c": bool(arch_cfg.get("stabilize_c", True)),
            "squash_before_bias": bool(arch_cfg.get("squash_before_bias", False)),
        }
        report = install_bc_stabilizer(lm, name=stabilizer_name, **swap_options)
        summary = (
            f"[model] BC stabilizer = {report.name!r} "
            f"(replaced={report.replaced}, B={report.stabilize_b}, C={report.stabilize_c}, "
            f"squash_before_bias={report.squash_before_bias})"
        )
        print(summary, flush=True)

    # Re-apply dtype, then device, after any module swap; presumably so that
    # freshly installed stabilizer modules match the rest of the model.
    for option, value in (("dtype", dtype), ("device", device)):
        if value is not None:
            lm.to(**{option: value})

    return lm
|
|