---
license: mit
---
# **Scaling Reasoning without Attention**
[![ArXiv](https://img.shields.io/badge/arXiv-2505.22425-red)](http://arxiv.org/abs/2505.22425)
[![GitHub](https://img.shields.io/badge/GitHub-PromptCoT-blue)](https://github.com/inclusionAI/PromptCoT)
---
## 🚀 Overview
**PromptCoT-Mamba** establishes the first **attention-free foundation model** capable of surpassing strong Transformer baselines across a broad suite of competition-level math and code reasoning tasks. Built on the **Mamba-2** architecture and trained through a structured, two-stage curriculum using the [**PromptCoT**](http://arxiv.org/abs/2503.02324) pipeline, it delivers **high accuracy with constant-memory inference**, eliminating the need for KV caching.
---
## 📈 Key Results
### 🔹 General Performance
| Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | LiveCodeBench |
| ---------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- |
| **PromptCoT-Mamba-7B** | 84.6 | **35.2** | **24.6** | 50.7 | 81.7 | 75.0 | **29.9** |
| Gemma3-27B | **89.0** | 32.6 | 24.0 | **54.2** | **86.0** | **78.0** | 26.9 |
| Gemma3-12B | 83.8 | 22.9 | 19.2 | 49.9 | 81.1 | 73.2 | 22.2 |
| Sky-T1-7B | 85.0 | 19.2 | 19.2 | 49.2 | 41.5 | 37.2 | 18.3 |
| S1.1-7B | 82.0 | 19.2 | 17.5 | 43.1 | 64.0 | 56.7 | 13.3 |
| Bespoke-Stratos-7B | 81.2 | 18.3 | 16.3 | 45.0 | 73.2 | 68.3 | 8.6 |
| Nemotron-H-8B | 77.6 | -- | -- | -- | 79.3 | 74.4 | -- |
| M1-3B | 81.7 | 23.0 | 22.0 | 43.6 | -- | -- | -- |
> πŸ” **PromptCoT-Mamba-7B** consistently outperforms all 7B-scale Transformer and hybrid Mamba-Transformer baselines across all tasks.
---
### 🔹 Math Specialization vs. Generalist
| Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | LiveCodeBench |
| --------------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- |
| **PromptCoT-Mamba-Math-7B** | **88.0** | **42.9** | **30.8** | **52.1** | 71.3 | 66.5 | 20.3 |
| PromptCoT-Mamba-7B | 84.6 | 35.2 | 24.6 | 50.7 | **81.7** | **75.0** | **29.9** |
> 🎯 The math-specialized variant improves AIME 24 by **+7.7 points** and AIME 25 by **+6.2 points**, with a slight trade-off in code-related performance.
---
### ⚡ Inference Efficiency
Using `vLLM` under constrained memory, PromptCoT-Mamba-7B demonstrates substantial speedups over the S1.1-7B Transformer baseline:
* 💡 **3.66× faster** at long-sequence generation on a **24GB GPU**
* 💡 **1.69× faster** under a **72GB** memory budget
> ⚙️ Practical for cost-sensitive or long-context inference workloads at scale.
---
## 🧪 Quick Start
### 🔧 Install Requirements
```bash
pip install transformers vllm torch accelerate
```
### 🧠 Load and Run the Model
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "xl-zhao/PromptCoT-Mamba-Math-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# bfloat16 keeps the 7B model within a single GPU's memory.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")

problem_statement = (
    "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
)

prompt = (
    f"<|im_start|>user\n{problem_statement}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    # do_sample=True is needed for the temperature setting to take effect.
    output = model.generate(**inputs, max_new_tokens=65536, temperature=0.8, do_sample=True)

generated_solution = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_solution)
```
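
If the tokenizer ships a chat template for this checkpoint (an assumption worth verifying), the hand-written `<|im_start|>` markers above can be replaced with `apply_chat_template`; this snippet reuses `tokenizer` and `problem_statement` from the example above:

```python
# Alternative prompt construction; assumes the tokenizer defines a chat template.
messages = [
    {
        "role": "user",
        "content": f"{problem_statement}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
    }
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```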
---
## ⚡ Fast Inference with vLLM
```python
from vllm import LLM, SamplingParams
model_name = "xl-zhao/PromptCoT-Mamba-Math-7B"
llm = LLM(model=model_name, tensor_parallel_size=1)
problem_statement = (
"A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
)
prompt = (
f"<|im_start|>user\n{problem_statement}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
"<|im_start|>assistant\n"
)
sampling_params = SamplingParams(temperature=0.8, max_tokens=65536)
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)
```
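
Since the prompt asks for the final answer inside `\boxed{}`, a small post-processing step can recover it from the generated text. The helper below is a simple sketch that assumes the boxed answer contains no nested braces:

```python
import re

def extract_boxed_answer(text: str) -> str | None:
    """Return the content of the last \\boxed{...}, assuming no nested braces."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

# e.g. "3" for the robe problem above.
print(extract_boxed_answer(outputs[0].outputs[0].text))
```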
---
## 📜 Citation
```bibtex
@article{zhao2025scaling,
author = {Xueliang Zhao and Wei Wu and Lingpeng Kong},
title = {Scaling Reasoning without Attention},
journal = {arXiv preprint arXiv:2505.22425},
year = {2025},
url = {https://arxiv.org/abs/2505.22425}
}
```