---
license: apache-2.0
datasets:
- Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories
base_model:
- inclusionAI/LLaDA2.0-mini
---

<div align="center">
<h1>🚀 DMax: Aggressive Parallel Decoding for dLLMs</h1>
<div align="center">
  <a href="https://github.com/czg1225/DMax/blob/main/LICENSE">
    <img alt="Apache" src="https://img.shields.io/badge/License-Apache-4E94CE.svg">
  </a>
  <a href="https://github.com/czg1225/DMax">
    <img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
  </a>
  <a href="https://github.com/czg1225/DMax">
    <img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub">
  </a>
</div>
</div>

> **DMax: Aggressive Parallel Decoding for dLLMs**
> [Zigeng Chen](https://czg1225.github.io/chenzigeng99/), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Ruonan Yu](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore

## 💪 Highlights

- **Aggressive Decoding Parallelism**: Achieves 6.0 TPF (tokens per forward pass) on math and reasoning tasks and 6.6 TPF on code tasks while preserving accuracy.
- **Self-Revising dLLM**: Extends a pretrained masked dLLM (MDLM) into a uniform dLLM (UDLM) with an intrinsic ability to revise its own erroneous predictions during decoding.
- **Soft Parallel Decoding**: Uses interpolation between mask and token embeddings to propagate confidence priors from previous steps.
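
The interpolation behind Soft Parallel Decoding can be pictured with a minimal sketch. The helper below is illustrative only (the function name and tensor shapes are assumptions, not part of the released API): each position's input embedding is a confidence-weighted blend of the mask embedding and the embedding of the token predicted in the previous step, so low-confidence positions stay close to the mask state and remain easy to revise.

```python
import torch

def soft_embed(mask_emb: torch.Tensor, token_emb: torch.Tensor,
               confidence: torch.Tensor) -> torch.Tensor:
    """Blend mask and predicted-token embeddings by per-position confidence.

    mask_emb:   (hidden,) embedding of the mask token.
    token_emb:  (batch, seq, hidden) embeddings of the previous step's predictions.
    confidence: (batch, seq) confidence in each prediction, in [0, 1].
    """
    c = confidence.unsqueeze(-1)  # (batch, seq, 1), broadcast over hidden dim
    # confidence 0 -> pure mask embedding; confidence 1 -> pure token embedding
    return (1.0 - c) * mask_emb + c * token_emb
```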

<div align="center">
<img src="assets/tradeoff.png" width="90%" />
<br>
<em>Superior parallelism-accuracy trade-off: increased TPF with maintained accuracy.</em>
</div>

## 💡 Introduction

We present DMax, a new paradigm for efficient dLLMs that mitigates error accumulation in parallel decoding, enabling aggressive decoding parallelism while preserving generation quality. Unlike conventional masked dLLMs, which decode through a binary mask-to-token transition, DMax reformulates decoding as progressive self-refinement from mask embeddings to token embeddings. At the core of our approach is On-Policy Uniform Training, a novel training strategy that efficiently unifies masked and uniform dLLMs, equipping the model to recover clean tokens from both masked inputs and its own erroneous predictions. Building on this foundation, we further introduce Soft Parallel Decoding. Extensive experiments across a variety of benchmarks demonstrate the effectiveness of DMax.

<!--  -->
<div align="center">
<img src="assets/train.png" width="100%" />
<br>
<em>Overview of On-Policy Uniform Training.</em>
</div>
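
As a rough illustration of the idea (not the released training code; the mask id, function name, and mixing scheme below are assumptions), on-policy uniform training can be thought of as corrupting a sequence with a mix of mask tokens and the model's own earlier predictions, then supervising recovery of the clean tokens at every position:

```python
import torch

MASK_ID = 0  # hypothetical mask token id

def corrupt_on_policy(clean_ids: torch.Tensor, model_ids: torch.Tensor,
                      mask_prob: float = 0.5):
    """Corrupt `clean_ids` with a mix of masks and on-policy predictions.

    clean_ids: (batch, seq) ground-truth token ids.
    model_ids: (batch, seq) the model's own predictions from a decoding
        trajectory, which may contain errors.
    Returns the corrupted input and the mask-position indicator.
    """
    # Each position is either masked (masked-dLLM style) or replaced by the
    # model's own prediction (uniform-dLLM style), so training teaches the
    # model to recover clean tokens from both kinds of input.
    masked = torch.rand(clean_ids.shape) < mask_prob
    corrupted = torch.where(masked, torch.full_like(clean_ids, MASK_ID), model_ids)
    return corrupted, masked
```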

## 💻 Model and Datasets

| Model | Description | Source Model | Link |
| --- | --- | --- | --- |
| 🤗 DMax-Math-16B | Highly parallel dLLM for math and reasoning. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Math-16B) |
| 🤗 DMax-Coder-16B | Highly parallel dLLM for code generation. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Coder-16B) |

| Dataset | Description | Link |
| --- | --- | --- |
| 📊 DMax-Math-Training-Data | Trajectories on math problems generated by LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories) |
| 📊 DMax-Code-Training-Data | Trajectories on code problems generated by LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Code-Trajectories) |

## 🚀 Quick Start

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer (custom modeling code requires trust_remote_code=True).
model = AutoModelForCausalLM.from_pretrained(
    "Zigeng/DMax-Math-16B", trust_remote_code=True, device_map="cuda:0"
)
model = model.to(torch.bfloat16)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Zigeng/DMax-Math-16B", trust_remote_code=True)

prompt = (
    "A robe takes 2 bolts of blue fiber and half that much white fiber. "
    "How many bolts in total does it take?"
    "\nLet's think step by step\n"
)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
)

# Soft Parallel Decoding: returns the number of forward evaluations (nfe)
# alongside the generated token ids.
nfe, generated_tokens = model.generate_spd(
    inputs=input_ids,
    gen_length=2048,
    block_length=32,
    threshold=0.0,
)

generated_answer = tokenizer.decode(
    generated_tokens[0],
    skip_special_tokens=True,
)

print(generated_answer)
print("nfe:", nfe, "token length:", len(generated_tokens[0]))
```
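
The TPF figures quoted in the highlights can be checked from the values returned above. This tiny helper is illustrative, not part of the model's API: TPF is simply the number of generated tokens divided by the number of forward evaluations (`nfe`).

```python
def tokens_per_forward(num_tokens: int, nfe: int) -> float:
    # TPF = generated tokens / forward evaluations; higher TPF means
    # more tokens are decoded per model call.
    return num_tokens / nfe

# e.g. tokens_per_forward(len(generated_tokens[0]), nfe)
print(tokens_per_forward(300, 50))  # -> 6.0
```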

## 📊 Experimental Results



---