---
license: mit
pipeline_tag: text-generation
library_name: transformers
---
<div align="center">
<h1>dParallel: Learnable Parallel Decoding for dLLMs</h1>
<div align="center">
<a href="https://opensource.org/license/mit-0">
<img alt="MIT" src="https://img.shields.io/badge/License-MIT-4E94CE.svg">
</a>
<a href="https://arxiv.org/pdf/2509.26488">
<img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
</a>
<a href="https://huggingface.co/Zigeng/dParallel-LLaDA-8B-instruct">
<img src="https://img.shields.io/badge/HuggingFace-Model-FFB000.svg" alt="Project">
</a>
<a href="https://huggingface.co/datasets/Zigeng/dParallel_LLaDA_Distill_Data">
<img src="https://img.shields.io/badge/HuggingFace-Data-FFB000.svg" alt="Project">
</a>
<a href="https://github.com/czg1225/dParallel">
<img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub">
</a>
</div>
</div>
https://github.com/user-attachments/assets/89d81255-9cd8-46d1-886e-0733938e5328
> **dParallel: Learnable Parallel Decoding for dLLMs**
> [Zigeng Chen](https://github.com/czg1225), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Ruonan Yu](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
## 💡 Introduction
We introduce dParallel, a simple and effective method that unlocks the inherent parallelism of diffusion large language models (dLLMs) for fast sampling. We identify the key bottleneck to parallel decoding as the sequential convergence of certainty for masked tokens. Building on this insight, we introduce the core of our approach, certainty-forcing distillation: a training strategy that distills the model to follow its original sampling trajectories while forcing it to reach high certainty on masked tokens more rapidly and in parallel. Extensive experiments across various benchmarks show that our method dramatically reduces the number of decoding steps while maintaining performance. When applied to the LLaDA-8B-Instruct model, dParallel reduces decoding steps from 256 to 30 on GSM8K, achieving an 8.5x speedup without performance degradation. On the MBPP benchmark, it cuts decoding steps from 256 to 24, resulting in a 10.5x speedup while maintaining accuracy.
<div align="center">
<img src="assets/method.png" width="100%" ></img>
<br>
<em>
Overview of proposed certainty-forcing distillation.
</em>
</div>
<br>
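
To make the bottleneck described in the introduction concrete, here is a minimal toy sketch of confidence-thresholded parallel decoding. It is not the repository's `generate` function: `parallel_decode`, `make_toy_model`, and the `decay` parameter are hypothetical names invented for illustration, and the "model" is a stand-in whose confidence simply falls off with distance from the leftmost masked position. The point is only that when certainty converges sequentially, a threshold rule admits one token per step, whereas a model that is certain about several masked tokens at once finishes in far fewer steps.

```python
from typing import Callable, List, Optional, Tuple

MASK = None  # placeholder for a position that is still masked

Prediction = Tuple[int, float]  # (predicted token, confidence)


def parallel_decode(
    predict: Callable[[List[Optional[int]]], List[Prediction]],
    length: int,
    threshold: float,
) -> Tuple[List[Optional[int]], int]:
    """Each step, commit every masked position whose confidence exceeds `threshold`."""
    seq: List[Optional[int]] = [MASK] * length
    steps = 0
    while any(tok is MASK for tok in seq):
        preds = predict(seq)  # stands in for one forward pass of the model
        steps += 1
        masked = [i for i, tok in enumerate(seq) if tok is MASK]
        accept = [i for i in masked if preds[i][1] > threshold]
        if not accept:
            # Always commit the most confident token so decoding makes progress.
            accept = [max(masked, key=lambda i: preds[i][1])]
        for i in accept:
            seq[i] = preds[i][0]
    return seq, steps


def make_toy_model(decay: float) -> Callable[[List[Optional[int]]], List[Prediction]]:
    """Hypothetical model: confidence decays with distance from the first masked slot."""
    def predict(seq: List[Optional[int]]) -> List[Prediction]:
        first_masked = next(i for i, tok in enumerate(seq) if tok is MASK)
        return [(i, decay ** max(0, i - first_masked)) for i in range(len(seq))]
    return predict


# Sequential certainty: only the next token clears the threshold each step.
_, slow_steps = parallel_decode(make_toy_model(0.5), length=16, threshold=0.6)
# Parallel certainty: several tokens clear the threshold in the same step.
_, fast_steps = parallel_decode(make_toy_model(0.9), length=16, threshold=0.6)
print(slow_steps, fast_steps)
```

In this toy setting the two runs produce the same sequence and differ only in the number of forward passes, which is the quantity dParallel aims to reduce through training rather than through a change of decoding rule.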
## 💻 Model and Datasets
<table>
<thead>
</thead>
<tbody>
<tr>
<td><strong>Paper</strong></td>
<td><a href="https://arxiv.org/pdf/2509.26488">ArXiv-Link</a></td>
</tr>
<tr>
<td>🤗 <strong>Model</strong></td>
<td><a href="https://huggingface.co/Zigeng/dParallel-LLaDA-8B-instruct">dParallel-LLaDA-8B-instruct</a></td>
</tr>
<tr>
<td><strong>Data</strong></td>
<td><a href="https://huggingface.co/datasets/Zigeng/dParallel_LLaDA_Distill_Data">
dParallel-LLaDA-Distill Dataset</a></td>
</tr>
</tbody>
</table>
## 🔥 Updates
* 🔥 **[Oct 1, 2025]**: Our arXiv paper is available.
* 🔥 **[Oct 1, 2025]**: Code, model, and dataset are released.
## 🔧 Installation:
```bash
conda create -n dparallel python==3.10
conda activate dparallel
pip3 install -r requirements.txt
```
## Quick Start:
```python
import torch
from transformers import AutoTokenizer

# Modules from the dParallel repository; run this script from the repository root.
from model.modeling_llada import LLaDAModelLM
from generate import generate

device = 'cuda'
model = LLaDAModelLM.from_pretrained('Zigeng/dParallel-LLaDA-8B-instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('Zigeng/dParallel-LLaDA-8B-instruct', trust_remote_code=True)

prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Please reason step by step, and put your final answer within \\boxed{}."

# Wrap the prompt in the chat template, then tokenize it into a batch of size 1.
m = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(prompt)['input_ids']
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

# out[0] holds the token ids (prompt followed by the generation); out[1] is the NFE count printed below.
out = generate(model, input_ids, steps=256, gen_length=256, block_length=32, temperature=0., threshold=0.5, remasking='low_confidence')
print("Response:", tokenizer.batch_decode(out[0][:, input_ids.shape[1]:], skip_special_tokens=True)[0])
print("NFE:", out[1])
```
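The NFE value printed by the Quick Start can be related to the step counts quoted in the introduction. The helper below is a small illustrative sketch and not part of the repository: `decoding_stats` and its `baseline_steps` default are hypothetical, and it assumes `out[1]` is an integer count of forward passes.

```python
def decoding_stats(gen_length: int, nfe: int, baseline_steps: int = 256) -> dict:
    """Summarize how many tokens were committed per forward pass.

    Hypothetical helper: `baseline_steps` is the one-token-per-step budget
    used for comparison (256 in the introduction's examples).
    """
    if nfe <= 0:
        raise ValueError("nfe must be a positive number of forward passes")
    return {
        "tokens_per_step": gen_length / nfe,
        "step_reduction": baseline_steps / nfe,
    }


# Step counts quoted in the introduction for GSM8K (30) and MBPP (24).
for name, nfe in [("GSM8K", 30), ("MBPP", 24)]:
    stats = decoding_stats(gen_length=256, nfe=nfe)
    print(name, round(stats["tokens_per_step"], 2), round(stats["step_reduction"], 2))
```

Note that the step reduction only counts forward passes. The 8.5x and 10.5x figures in the introduction are reported speedups and need not equal this ratio exactly.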
## ⚡ Evaluation:
We provide evaluation scripts covering the GSM8K, Minerva_MATH, HumanEval, and MBPP benchmarks. Importantly, our reported results were obtained, and the accompanying code runs, without caching or sparse attention techniques. Nevertheless, our method is fully compatible with these optimizations, and integrating them can yield even greater speedups.
```bash
sh eval.sh
```
## 🔥 Training
### 1. Certainty-Forcing Distillation with LoRA:
We provide training scripts for our proposed Certainty-Forcing Distillation process. The implementation uses LoRA during training, with the configuration specified in [config_lora_llada.yaml](https://github.com/czg1225/dParallel/blob/master/configs/config_lora_llada.yaml). Training can be completed on GPUs with 24 GB of memory.
```bash
deepspeed --master_port 29501 --include localhost:0,1,2,3,4,5,6,7 llada_train.py
```
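The exact training objective is defined in the paper and in `llada_train.py`. As a rough intuition only, the toy sketch below shows one plausible reading of the introduction's description: a cross-entropy term that keeps masked-token predictions on the teacher's trajectory, plus an entropy penalty that pushes those predictions toward high certainty. Everything here is an assumption made for illustration, including the function name `toy_certainty_forcing_loss`, the weight `beta`, and the choice to apply the entropy term to every masked position.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a single position's logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-probability of the teacher's token."""
    return -math.log(softmax(logits)[target])


def entropy(logits: List[float]) -> float:
    """Shannon entropy of the predictive distribution (low means certain)."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


def toy_certainty_forcing_loss(
    masked_logits: List[List[float]],
    teacher_tokens: List[int],
    beta: float = 1.0,
) -> float:
    """Hypothetical loss: mean over masked positions of CE + beta * entropy."""
    if not masked_logits:
        return 0.0
    per_position = [
        cross_entropy(logits, target) + beta * entropy(logits)
        for logits, target in zip(masked_logits, teacher_tokens)
    ]
    return sum(per_position) / len(per_position)


# An uncertain prediction is penalized more than a confident, correct one.
uncertain = toy_certainty_forcing_loss([[0.0, 0.0, 0.0]], [0])
confident = toy_certainty_forcing_loss([[10.0, 0.0, 0.0]], [0])
print(uncertain > confident)
```

Under this reading, the entropy term is what rewards the student for becoming certain about many masked tokens early, which is the property a confidence threshold can then exploit at decoding time.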
### 2. LoRA Merge:
After training, merge the LoRA weights into the base model to obtain the dParallel dLLM.
```bash
python merge_lora.py
```
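For readers unfamiliar with what a LoRA merge does, the sketch below shows the standard LoRA arithmetic on plain Python lists: the low-rank update `B @ A`, scaled by `alpha / r`, is added to the frozen weight `W`. This is not the contents of `merge_lora.py`, which may rely on a library routine instead; `merge_lora_weights` and the tiny matrices are hypothetical and exist only to illustrate the formula.

```python
from typing import List

Matrix = List[List[float]]


def merge_lora_weights(W: Matrix, A: Matrix, B: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * (B @ A).

    Shapes: W is (out, in), B is (out, r), A is (r, in).
    """
    scale = alpha / r
    merged = [row[:] for row in W]  # copy so the base weight is left untouched
    for i in range(len(W)):
        for j in range(len(W[0])):
            delta = sum(B[i][k] * A[k][j] for k in range(r))
            merged[i][j] += scale * delta
    return merged


# Rank-1 toy example with a 2x2 base weight.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]    # (r=1, in=2)
B = [[1.0], [3.0]]  # (out=2, r=1)
print(merge_lora_weights(W, A, B, alpha=2.0, r=1))
```

After merging, the adapter matrices are no longer needed at inference time, so the merged model can be loaded like an ordinary checkpoint.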
## Experimental Results
### Results on LLaDA-8B-Instruct:

### Results on Dream-7B-Instruct:

### Better Speed-Accuracy Trade-off:

## Acknowledgement
Our code builds on [LLaDA](https://github.com/ML-GSAI/LLaDA), [Dream](https://github.com/DreamLM/Dream), [Fast-dLLM](https://github.com/NVlabs/Fast-dLLM/tree/main), and [dKV-Cache](https://github.com/horseee/dkv-cache), and we acknowledge these great works for laying the groundwork that made our approach possible.
## Citation
If our research assists your work, please give us a star ⭐ or cite us using:
```
@misc{chen2025dparallellearnableparalleldecoding,
title={dParallel: Learnable Parallel Decoding for dLLMs},
author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Ruonan Yu and Xinchao Wang},
year={2025},
eprint={2509.26488},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2509.26488},
}
```