ALTAMBA 1.78B
ALTAMBA (Alternating Mamba) is a 1.78B-parameter dual-path language model that runs a Transformer and a Mamba-2 state space model in parallel at every layer, blending them through learned destructive and constructive interference.
Architecture
| Parameter | Value |
|---|---|
| Parameters | 1.78B |
| Layers | 12 |
| d_model | 2560 |
| Attention heads | 8 |
| FFN expansion | 4x |
| Denoiser scale | 0.75x |
| Context length | 256 |
| Vocab size | 100,288 |
| Tokenizer | tiktoken cl100k_base |
Key Innovations
- Dual-path: Transformer + Mamba-2 SSM run in parallel at every layer
- Peristaltic Normalization: Alternating Post-LayerNorm (even layers, variance clamping for denoising) and Pre-LayerNorm (odd layers, gradient flow for blending)
- The Razor: Learnable blending operation `y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main]` with per-layer parameters and output scaling (see the sketch after this list)
- Role reversal: Even layers use the SSM as the main signal and the Transformer as the denoiser; odd layers reverse the roles
- Bottleneck compression: 0.75x denoiser paths for parameter efficiency
- Bounded-dt fix: Sigmoid-bounded discretization timestep for Mamba-2 stability
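
A minimal sketch of the Razor blend and the bounded-dt fix follows. It is illustrative rather than the released implementation: the names `RazorBlend` and `bounded_dt`, the fixed offsets `c` and `c_prime`, and the dt bounds are assumptions; only the blend formula, the W1 = 1.5 / W2 = -0.5 initialization, and the even/odd role reversal come from the description above.

```python
import torch
import torch.nn as nn

class RazorBlend(nn.Module):
    """Sketch of the Razor: y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main]."""
    def __init__(self, d_model, init_w1=1.5, init_w2=-0.5, c=1.0, c_prime=1.0):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor(init_w1))   # per-layer learnable weight, denoiser path
        self.w2 = nn.Parameter(torch.tensor(init_w2))   # per-layer learnable weight, main path
        self.scale = nn.Parameter(torch.tensor(1.0))    # output scaling s
        self.c, self.c_prime = c, c_prime               # fixed offsets (values assumed)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x_denoiser, x_main):
        return self.scale * ((self.c - self.w1) * self.ln(x_denoiser)
                             + (self.c_prime - self.w2) * x_main)

def blend_layer(layer_idx, attn_out, ssm_out, razor):
    # Role reversal: even layers treat the SSM output as the main signal,
    # odd layers treat the Transformer output as the main signal.
    if layer_idx % 2 == 0:
        return razor(x_denoiser=attn_out, x_main=ssm_out)
    return razor(x_denoiser=ssm_out, x_main=attn_out)

def bounded_dt(dt_raw, dt_min=1e-3, dt_max=0.1):
    # Bounded-dt fix (sketch): squash the raw discretization timestep into (dt_min, dt_max)
    # with a sigmoid instead of an unbounded parameterization. Bounds here are placeholders.
    return dt_min + (dt_max - dt_min) * torch.sigmoid(dt_raw)
```

If c = c' = 1 (an assumption), the W1 = 1.5 / W2 = -0.5 initialization gives the denoiser path a starting coefficient of -0.5 and the main path +1.5, consistent with the destructive/constructive interference framing.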
Results
| Scale | Baseline Val Loss | ALTAMBA Val Loss | Improvement |
|---|---|---|---|
| 402M | 3.1866 | 2.8886 | 9.35% |
| 1.08B | 2.9771 | 2.6974 | 9.40% |
| 1.78B | 2.4427 | 2.2554 | 7.66% |
Baselines are parameter-matched Jamba (1:7 attention-to-Mamba ratio). Trained on Common Pile (arXiv subset).
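
The Improvement column is the relative reduction in validation loss versus the parameter-matched baseline, for example at 402M:

```python
# Relative improvement = (baseline_loss - altamba_loss) / baseline_loss
baseline, altamba = 3.1866, 2.8886
print(f"{100 * (baseline - altamba) / baseline:.2f}%")  # 9.35%
```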
Checkpoint Details
- File: `best_model_altamba_fp16.pt`
- Format: PyTorch state dict (fp16)
- Training step: 8625
- Validation loss: 2.2554
- State dict keys: Direct `TransformerSSMDenoise` keys (no wrapper prefix)
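
A quick way to confirm the checkpoint layout before constructing the model. This is a generic inspection snippet; the printed key names are whatever `TransformerSSMDenoise` defines, and the only claim above is that there is no wrapper prefix (e.g. `module.`) to strip:

```python
import torch

# Load the state dict on CPU and inspect the first few key names and shapes.
state_dict = torch.load("best_model_altamba_fp16.pt", map_location="cpu", weights_only=True)
for name, tensor in list(state_dict.items())[:10]:
    print(name, tuple(tensor.shape), tensor.dtype)
```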
Usage
```python
import torch
import tiktoken
from transformer_ssm_denoise import TransformerSSMDenoise

# Tokenizer
enc = tiktoken.get_encoding("cl100k_base")
vocab_size = 100288

# Create model
model = TransformerSSMDenoise(
    vocab_size=vocab_size,
    d_model=2560,
    n_layers=12,
    n_heads=8,
    d_ff=10240,
    dropout=0.0,
    init_w1=1.5,
    init_w2=-0.5,
    use_denoising=True,
    baseline_type="dual",
    use_scaling=True,
    denoiser_scale_ssm=0.75,
    denoiser_scale_transformer=0.75,
    use_odd_ssm=True,
)

# Load checkpoint
state_dict = torch.load("best_model_altamba_fp16.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.half().eval().cuda()

# Generate with temperature sampling (T = 0.8)
prompt = "The fundamental theorem"
tokens = enc.encode(prompt)
input_ids = torch.tensor([tokens], dtype=torch.long, device="cuda")

with torch.no_grad():
    for _ in range(100):
        # Keep the input within the 256-token training context
        logits = model(input_ids[:, -256:])
        next_logits = logits[:, -1, :] / 0.8
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(enc.decode(input_ids[0].tolist()))
```
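
The validation losses reported above are mean next-token cross-entropy in nats. A minimal sketch of computing it for a single sequence, assuming (as in the generation loop) that the forward pass returns logits of shape (batch, seq, vocab):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sequence_loss(model, token_ids):
    """Mean next-token cross-entropy for one tokenized sequence (illustrative only)."""
    # Keep the sequence within the 256-token training context.
    ids = torch.tensor([token_ids[:257]], dtype=torch.long, device="cuda")
    logits = model(ids[:, :-1])   # predict token t+1 from tokens <= t
    targets = ids[:, 1:]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)).item()

# A loss of 2.2554 corresponds to a perplexity of exp(2.2554) ≈ 9.5.
```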
Dependencies
- PyTorch >= 2.1.0
- mamba-ssm (requires CUDA)
- tiktoken
Architecture code available at github.com/altamba/altamba.
Training Configuration
- Optimizer: AdamW (lr=1e-4, cosine annealing)
- Batch size: 8 (gradient accumulation 8, effective 64)
- Precision: BF16 mixed precision
- Gradient clipping: 1.0
- Dropout: 0.2
- Hardware: NVIDIA RTX PRO 6000 Blackwell Server Edition (96GB)
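
A hedged sketch of that setup in PyTorch. Only AdamW at lr 1e-4, cosine annealing, BF16 autocast, gradient clipping at 1.0, and 8-step gradient accumulation are taken from the list above; `train_loader`, the schedule length, and the loss computation are placeholders.

```python
import torch
import torch.nn.functional as F

accum_steps = 8                   # micro-batch 8 x accumulation 8 = effective batch 64
total_optimizer_steps = 10_000    # placeholder; the actual schedule length is not stated above

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_optimizer_steps)

# train_loader is assumed to yield (B, 257) LongTensors of token ids on the GPU.
for step, batch in enumerate(train_loader):
    input_ids, targets = batch[:, :-1], batch[:, 1:]
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):   # BF16 mixed precision
        logits = model(input_ids)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)      # gradient clipping at 1.0
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```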
Citation
```bibtex
@article{seto2026altamba,
  title={ALTAMBA: Alternating Mamba with Peristaltic Normalization},
  author={Seto, Scott},
  year={2026},
  doi={10.5281/zenodo.18521311}
}
```
License
Apache 2.0