---
license: apache-2.0
language:
- en
tags:
- mamba
- ssm
- transformer
- text-generation
- hybrid
library_name: pytorch
pipeline_tag: text-generation
---

# ALTAMBA 1.78B

**ALTAMBA** (Alternating Mamba) is a 1.78B-parameter dual-path language model that runs a Transformer and a State Space Model (Mamba-2) in parallel at every layer, blending the two streams through learned destructive and constructive interference.

## Architecture

| Parameter | Value |
|-----------|-------|
| Parameters | 1.78B |
| Layers | 12 |
| d_model | 2560 |
| Attention heads | 8 |
| FFN expansion | 4x |
| Denoiser scale | 0.75x |
| Context length | 256 |
| Vocab size | 100,288 |
| Tokenizer | tiktoken cl100k_base |

### Key Innovations

- **Dual-path**: Transformer + Mamba-2 SSM run in parallel at every layer
- **Peristaltic Normalization**: Alternating Post-LayerNorm (even layers, variance clamping for denoising) and Pre-LayerNorm (odd layers, gradient flow for blending)
- **The Razor**: Learnable blending operation `y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main]` with per-layer parameters and output scaling
- **Role reversal**: Even layers use the SSM as the main signal and the Transformer as the denoiser; odd layers reverse the roles
- **Bottleneck compression**: 0.75x-width denoiser paths for parameter efficiency
- **Bounded-dt fix**: Sigmoid-bounded discretization timestep for Mamba-2 stability

## Results

| Scale | Baseline Val Loss | ALTAMBA Val Loss | Improvement |
|-------|-------------------|------------------|-------------|
| 402M | 3.1866 | 2.8886 | 9.35% |
| 1.08B | 2.9771 | 2.6974 | 9.40% |
| 1.78B | 2.4427 | 2.2554 | 7.66% |

Baselines are parameter-matched Jamba models (1:7 attention-to-Mamba layer ratio). All models were trained on the arXiv subset of Common Pile.
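The Razor operation above can be sketched as a small PyTorch module. This is an illustrative sketch, not the released implementation: the card does not define the constants `c` and `c'` (1.0 is assumed here), the class name `Razor` is hypothetical, and the initial values for `W1` and `W2` are taken from the `init_w1`/`init_w2` arguments shown in the usage example.

```python
import torch
import torch.nn as nn

class Razor(nn.Module):
    """Sketch of the learned blend y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main].

    Assumptions: c = c' = 1.0 (unspecified in the card); s, W1, W2 are learned scalars.
    """

    def __init__(self, d_model, init_w1=1.5, init_w2=-0.5, c=1.0, c_prime=1.0):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor(init_w1))
        self.w2 = nn.Parameter(torch.tensor(init_w2))
        self.s = nn.Parameter(torch.tensor(1.0))   # output scaling
        self.ln = nn.LayerNorm(d_model)            # normalize the denoiser path only
        self.c, self.c_prime = c, c_prime

    def forward(self, x_main, x_denoiser):
        # Negative effective coefficients give destructive interference,
        # positive ones give constructive interference.
        return self.s * ((self.c - self.w1) * self.ln(x_denoiser)
                         + (self.c_prime - self.w2) * x_main)
```

With the defaults, `(c - W1) = -0.5` at initialization, so the denoiser path starts out subtracted from the main signal.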
## Checkpoint Details

- **File**: `best_model_altamba_fp16.pt`
- **Format**: PyTorch state dict (fp16)
- **Training step**: 8625
- **Validation loss**: 2.2554
- **State dict keys**: Direct `TransformerSSMDenoise` keys (no wrapper prefix)

## Usage

```python
import torch
import tiktoken

from transformer_ssm_denoise import TransformerSSMDenoise

# Tokenizer
enc = tiktoken.get_encoding("cl100k_base")
vocab_size = 100288

# Create model
model = TransformerSSMDenoise(
    vocab_size=vocab_size,
    d_model=2560,
    n_layers=12,
    n_heads=8,
    d_ff=10240,
    dropout=0.0,
    init_w1=1.5,
    init_w2=-0.5,
    use_denoising=True,
    baseline_type="dual",
    use_scaling=True,
    denoiser_scale_ssm=0.75,
    denoiser_scale_transformer=0.75,
    use_odd_ssm=True,
)

# Load checkpoint
state_dict = torch.load("best_model_altamba_fp16.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.half().eval().cuda()

# Generate with temperature sampling
prompt = "The fundamental theorem"
tokens = enc.encode(prompt)
input_ids = torch.tensor([tokens], dtype=torch.long, device="cuda")

with torch.no_grad():
    for _ in range(100):
        # Keep the input within the model's 256-token training context
        logits = model(input_ids[:, -256:])
        next_logits = logits[:, -1, :] / 0.8  # temperature 0.8
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(enc.decode(input_ids[0].tolist()))
```

## Dependencies

- PyTorch >= 2.1.0
- [mamba-ssm](https://github.com/state-spaces/mamba) (requires CUDA)
- tiktoken

Architecture code is available at [github.com/altamba/altamba](https://github.com/altamba/altamba).
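If plain temperature sampling is too diffuse, a top-k filter can be swapped into the generation loop above. This helper is an illustrative addition, not part of the ALTAMBA codebase; `sample_top_k` and its defaults are assumptions.

```python
import torch

def sample_top_k(next_logits, k=50, temperature=0.8):
    """Sample one token id per batch row, restricted to the k most likely tokens.

    next_logits: (batch, vocab) logits for the next position.
    Returns: (batch, 1) sampled token ids.
    """
    scaled = next_logits / temperature
    topk_vals, topk_idx = torch.topk(scaled, k, dim=-1)   # (batch, k)
    probs = torch.softmax(topk_vals, dim=-1)              # renormalize over top-k only
    choice = torch.multinomial(probs, num_samples=1)      # index into the top-k set
    return topk_idx.gather(-1, choice)                    # map back to vocab ids
```

Inside the loop, `next_token = sample_top_k(logits[:, -1, :])` would then replace the temperature/softmax/multinomial lines.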
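The training configuration listed in the next section maps onto standard PyTorch components. A minimal sketch under stated assumptions: the hyperparameters come from the card, `model` is a stand-in module (the real run uses `TransformerSSMDenoise`), the cosine schedule is assumed to span the 8625 recorded steps, and BF16 autocast is omitted for brevity.

```python
import torch

# Stand-in module; the real run trains TransformerSSMDenoise
model = torch.nn.Linear(2560, 2560)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Cosine annealing; T_max=8625 assumed from the checkpoint's training step
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8625)

accum_steps = 8  # micro-batch 8 x accumulation 8 = effective batch 64

def train_step(batches, loss_fn):
    """One optimizer step: gradient accumulation, clipping at 1.0, then scheduler tick."""
    optimizer.zero_grad(set_to_none=True)
    for x, y in batches:
        loss = loss_fn(model(x), y) / accum_steps  # average across micro-batches
        loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
```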
## Training Configuration

- **Optimizer**: AdamW (lr=1e-4, cosine annealing)
- **Batch size**: 8 (gradient accumulation 8, effective 64)
- **Precision**: BF16 mixed precision
- **Gradient clipping**: 1.0
- **Dropout**: 0.2
- **Hardware**: NVIDIA RTX PRO 6000 Blackwell Server Edition (96GB)

## Citation

```bibtex
@article{seto2026altamba,
  title={ALTAMBA: Alternating Mamba with Peristaltic Normalization},
  author={Seto, Scott},
  year={2026},
  doi={10.5281/zenodo.18521311}
}
```

## License

Apache 2.0