---
base_model:
- inclusionAI/LLaDA2.0-mini
datasets:
- Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
---

<div align="center">
<h1>🚀 DMax: Aggressive Parallel Decoding for dLLMs</h1>
  <div align="center">
  <a href="https://github.com/czg1225/DMax/blob/main/LICENSE">
    <img alt="Apache" src="https://img.shields.io/badge/License-Apache-4E94CE.svg">
  </a>
  <a href="https://arxiv.org/abs/2604.08302">
    <img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
  </a>
  <a href="https://github.com/czg1225/DMax">
    <img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub">
  </a>
</div>
</div>

This repository contains the weights for **DMax-Math-16B**, presented in the paper [DMax: Aggressive Parallel Decoding for dLLMs](https://huggingface.co/papers/2604.08302).

DMax is a new paradigm for efficient diffusion language models (dLLMs) that mitigates error accumulation in parallel decoding, enabling aggressive decoding parallelism while preserving generation quality.

## 💪 Highlights

- **Aggressive Decoding Parallelism**: Achieves 6.0 TPF (tokens per forward pass) on math and reasoning tasks and 6.6 TPF on code tasks while preserving accuracy.
- **Self-Revising dLLM**: Extends a pretrained MDLM into a UDLM with an intrinsic ability to revise its own erroneous predictions during decoding.
- **Soft Parallel Decoding**: Uses interpolation between mask and token embeddings to propagate confidence priors from previous steps.
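The soft parallel decoding idea above can be sketched as a confidence-weighted interpolation between the mask embedding and the predicted token embedding. This is a minimal illustration of the concept, not the exact mixing rule from the paper; `soft_embed` and its arguments are hypothetical names.

```python
import numpy as np

def soft_embed(token_emb, mask_emb, confidence):
    """Interpolate between the [MASK] embedding and the predicted token
    embedding, weighted by the previous step's confidence. A sketch of
    soft parallel decoding; DMax's actual mixing rule may differ."""
    c = confidence[:, None]  # broadcast per-position confidence over the embedding dim
    return c * token_emb + (1.0 - c) * mask_emb

# Toy example: 3 sequence positions, embedding dim 4.
token_emb = np.ones((3, 4))   # stand-in for predicted-token embeddings
mask_emb = np.zeros((3, 4))   # stand-in for the [MASK] embedding
conf = np.array([0.0, 0.5, 1.0])

mixed = soft_embed(token_emb, mask_emb, conf)
print(mixed[:, 0].tolist())  # [0.0, 0.5, 1.0]
```

A position the model was confident about in the previous step thus enters the next forward pass looking almost like a committed token, while a low-confidence position stays close to the mask embedding.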

<div align="center">
  <img src="assets/tradeoff.png" width="100%" />
  <br>
  <em>Superior parallelism-accuracy trade-off: increased TPF with accuracy maintained.</em>
</div>


## 💻 Model and Datasets

| Model | Description | Source Model | Link |
| --- | --- | --- | --- |
| 🤖 DMax-Math-16B | Highly parallel dLLM for math and reasoning. | LLaDA-2.0-mini | [HF](https://huggingface.co/Zigeng/DMax-Math-16B) |
| 🤖 DMax-Coder-16B | Highly parallel dLLM for code generation. | LLaDA-2.0-mini | [HF](https://huggingface.co/Zigeng/DMax-Coder-16B) |

| Dataset | Description | Link |
| --- | --- | --- |
| 📊 DMax-Math-Training-Data | Math trajectories generated by LLaDA-2.0-mini | [HF](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories) |
| 📊 DMax-Code-Training-Data | Code trajectories generated by LLaDA-2.0-mini | [HF](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Code-Trajectories) |

## 🚀 Quick Start

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Zigeng/DMax-Math-16B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Zigeng/DMax-Math-16B", trust_remote_code=True)

prompt = (
    "A robe takes 2 bolts of blue fiber and half that much white fiber. "
    "How many bolts in total does it take?\nLet's think step by step\n"
)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
).to(model.device)

nfe, generated_tokens = model.generate_spd(
    inputs=input_ids,
    gen_length=2048,
    block_length=32,
    threshold=0.5,
)

generated_answer = tokenizer.decode(
    generated_tokens[0],
    skip_special_tokens=True,
)

print(generated_answer)
print("nfe:", nfe, "token length:", len(generated_tokens[0]))
```
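The `threshold` argument controls how aggressively positions are committed in parallel: at each step, every masked position whose top probability clears the threshold is decoded at once. The sketch below illustrates that idea only; it is not the actual `generate_spd` implementation (which, per the paper, can also revise already-decoded tokens), and `parallel_unmask_step` is a hypothetical name.

```python
import numpy as np

def parallel_unmask_step(probs, masked, threshold=0.5):
    """One step of confidence-thresholded parallel unmasking (illustrative).
    probs:  (seq_len, vocab) per-position token probabilities
    masked: (seq_len,) boolean mask of still-undecoded positions
    Returns the argmax tokens and a boolean array of positions to commit."""
    top_prob = probs.max(axis=-1)
    top_tok = probs.argmax(axis=-1)
    commit = masked & (top_prob >= threshold)
    if not commit.any():
        # Guarantee progress: commit the single most confident masked position.
        idx = np.where(masked, top_prob, -1.0).argmax()
        commit = np.zeros_like(masked)
        commit[idx] = True
    return top_tok, commit

# Toy example: 3 masked positions over a 2-token vocabulary.
probs = np.array([[0.9, 0.1], [0.6, 0.4], [0.45, 0.55]])
masked = np.array([True, True, True])
tok, commit = parallel_unmask_step(probs, masked, threshold=0.5)
print(commit.tolist())  # [True, True, True]
```

A lower threshold commits more tokens per forward pass (higher TPF) at the cost of more early mistakes, which is exactly where DMax's self-revision is meant to recover accuracy.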

## 📖 Experimental Results

![trade-off](assets/exp.png)

## 📚 Citation

```bibtex
@misc{chen2026dmaxaggressiveparalleldecoding,
      title={DMax: Aggressive Parallel Decoding for dLLMs}, 
      author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Ruonan Yu and Xinchao Wang},
      year={2026},
      eprint={2604.08302},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2604.08302}, 
}
```