---
license: apache-2.0
language:
- en
tags:
- mamba
- ssm
- transformer
- text-generation
- hybrid
library_name: pytorch
pipeline_tag: text-generation
---

# ALTAMBA 1.78B

**ALTAMBA** (Alternating Mamba) is a 1.78B-parameter dual-path language model that runs a Transformer and a state space model (Mamba-2) in parallel at every layer, blending the two streams through learned destructive and constructive interference.

## Architecture

| Parameter | Value |
|-----------|-------|
| Parameters | 1.78B |
| Layers | 12 |
| d_model | 2560 |
| Attention heads | 8 |
| FFN expansion | 4x |
| Denoiser scale | 0.75x |
| Context length | 256 |
| Vocab size | 100,288 |
| Tokenizer | tiktoken cl100k_base |

### Key Innovations

- **Dual path**: a Transformer and a Mamba-2 SSM run in parallel at every layer
- **Peristaltic Normalization**: alternating Post-LayerNorm (even layers, with variance clamping for denoising) and Pre-LayerNorm (odd layers, preserving gradient flow for blending)
- **The Razor**: a learnable blending operation `y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main]` with per-layer parameters and output scaling
- **Role reversal**: even layers use the SSM as the main signal and the Transformer as the denoiser; odd layers reverse the roles
- **Bottleneck compression**: 0.75x-width denoiser paths for parameter efficiency
- **Bounded-dt fix**: a sigmoid-bounded discretization timestep for Mamba-2 stability

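The Razor formula above can be sketched as a small PyTorch module. This is a minimal illustration under stated assumptions, not the reference implementation: the constants `c` and `c'` and the learnability of `s` are guesses, and the `init_w1=1.5` / `init_w2=-0.5` defaults are borrowed from the usage code in this card.

```python
import torch
import torch.nn as nn


class RazorBlend(nn.Module):
    """Sketch of the Razor: y = s * [(c - W1) * LN(x_denoiser) + (c' - W2) * x_main].

    Assumptions (not confirmed by the card): c and c' are fixed constants,
    while W1, W2, and the output scale s are learnable per-layer scalars.
    """

    def __init__(self, d_model, init_w1=1.5, init_w2=-0.5, c=1.0, c_prime=0.0):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor(init_w1))  # learnable per-layer weight
        self.w2 = nn.Parameter(torch.tensor(init_w2))
        self.s = nn.Parameter(torch.tensor(1.0))       # learnable output scale
        self.c, self.c_prime = c, c_prime              # hypothetical fixed constants
        self.ln = nn.LayerNorm(d_model)                # LN applied to the denoiser path

    def forward(self, x_denoiser, x_main):
        # Negative effective coefficients give destructive interference,
        # positive ones constructive, as described above.
        return self.s * ((self.c - self.w1) * self.ln(x_denoiser)
                         + (self.c_prime - self.w2) * x_main)


blend = RazorBlend(d_model=2560)
y = blend(torch.randn(1, 4, 2560), torch.randn(1, 4, 2560))  # (batch, seq, d_model)
```

With the defaults above, the denoiser path enters with coefficient `c - W1 = -0.5` and the main path with `c' - W2 = 0.5`, i.e. the denoiser is subtracted from the main signal at initialization.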
## Results

| Scale | Baseline Val Loss | ALTAMBA Val Loss | Improvement |
|-------|-------------------|------------------|-------------|
| 402M | 3.1866 | 2.8886 | 9.35% |
| 1.08B | 2.9771 | 2.6974 | 9.40% |
| 1.78B | 2.4427 | 2.2554 | 7.66% |

Baselines are parameter-matched Jamba models (1:7 attention-to-Mamba ratio). All models were trained on the arXiv subset of Common Pile.

## Checkpoint Details

- **File**: `best_model_altamba_fp16.pt`
- **Format**: PyTorch state dict (fp16)
- **Training step**: 8625
- **Validation loss**: 2.2554
- **State dict keys**: direct `TransformerSSMDenoise` keys (no wrapper prefix)

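Because the keys carry no wrapper prefix, the checkpoint loads directly into `TransformerSSMDenoise`. If you re-save it yourself through a training wrapper (e.g. DDP or Lightning, which prepend `model.` or `module.`), a small helper — hypothetical, not part of this repo — can strip the prefix back off:

```python
def strip_prefix(state_dict, prefix="model."):
    """Remove a wrapper prefix from state-dict keys; keys without it pass through."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}


# Example: a wrapper-saved dict maps back to direct keys.
wrapped = {"model.embed.weight": 1, "lm_head.weight": 2}
clean = strip_prefix(wrapped)
```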
## Usage

```python
import torch
import tiktoken

from transformer_ssm_denoise import TransformerSSMDenoise

# Tokenizer
enc = tiktoken.get_encoding("cl100k_base")
vocab_size = 100288

# Create the model with the 1.78B configuration
model = TransformerSSMDenoise(
    vocab_size=vocab_size,
    d_model=2560,
    n_layers=12,
    n_heads=8,
    d_ff=10240,
    dropout=0.0,  # dropout disabled at inference
    init_w1=1.5,
    init_w2=-0.5,
    use_denoising=True,
    baseline_type="dual",
    use_scaling=True,
    denoiser_scale_ssm=0.75,
    denoiser_scale_transformer=0.75,
    use_odd_ssm=True,
)

# Load the fp16 checkpoint
state_dict = torch.load("best_model_altamba_fp16.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.half().eval().cuda()

# Generate with temperature sampling (temperature 0.8)
prompt = "The fundamental theorem"
tokens = enc.encode(prompt)
input_ids = torch.tensor([tokens], dtype=torch.long, device="cuda")

with torch.no_grad():
    for _ in range(100):
        # Feed at most the last 256 tokens, matching the model's context length
        logits = model(input_ids[:, -256:])
        next_logits = logits[:, -1, :] / 0.8
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(enc.decode(input_ids[0].tolist()))
```

## Dependencies

- PyTorch >= 2.1.0
- [mamba-ssm](https://github.com/state-spaces/mamba) (requires CUDA)
- tiktoken

Architecture code is available at [github.com/altamba/altamba](https://github.com/altamba/altamba).

## Training Configuration

- **Optimizer**: AdamW (lr=1e-4, cosine annealing)
- **Batch size**: 8 (gradient accumulation 8, effective batch 64)
- **Precision**: BF16 mixed precision
- **Gradient clipping**: 1.0
- **Dropout**: 0.2
- **Hardware**: NVIDIA RTX PRO 6000 Blackwell Server Edition (96 GB)

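The configuration above can be sketched as a training-loop skeleton. This is illustrative only: `nn.Linear` stands in for the real model, the data and step count are synthetic, and the BF16 autocast context is omitted so the sketch runs on CPU.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)       # stand-in for TransformerSSMDenoise
total_steps, accum = 50, 8      # gradient accumulation x8 -> effective batch 64

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

for step in range(total_steps):
    optimizer.zero_grad()
    for _ in range(accum):
        x = torch.randn(8, 16)                    # micro-batch of 8
        loss = model(x).pow(2).mean() / accum     # scale loss by accumulation steps
        loss.backward()                           # gradients accumulate across micro-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip at norm 1.0
    optimizer.step()
    scheduler.step()                              # cosine-anneal the learning rate
```

In the real run, the forward/backward pass would sit inside `torch.autocast("cuda", dtype=torch.bfloat16)` for BF16 mixed precision.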
## Citation

```bibtex
@article{seto2026altamba,
  title={ALTAMBA: Alternating Mamba with Peristaltic Normalization},
  author={Seto, Scott},
  year={2026},
  doi={10.5281/zenodo.18521311}
}
```

## License

Apache 2.0