Zigeng commited on
Commit
71d408c
Β·
verified Β·
1 Parent(s): 8100777

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +139 -3
README.md CHANGED
@@ -1,3 +1,139 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+
6
+ <div align="center">
7
+ <h1>πŸš€ dParallel: Learnable Parallel Decoding for dLLMs</h1>
8
+ <div align="center">
9
+ <a href="https://opensource.org/license/mit-0">
10
+ <img alt="MIT" src="https://img.shields.io/badge/License-MIT-4E94CE.svg">
11
+ </a>
12
+ <a href="https://github.com/czg1225/dParallel">
13
+ <img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
14
+ </a>
15
+ <a href="https://huggingface.co/Zigeng/dParallel-LLaDA-8b-instruct">
16
+ <img src="https://img.shields.io/badge/HuggingFace-Model-FFB000.svg" alt="Project">
17
+ </a>
18
+ <a href="https://huggingface.co/datasets/Zigeng/dParallel_LLaDA_Distill_Data">
19
+ <img src="https://img.shields.io/badge/HuggingFace-Data-FFB000.svg" alt="Project">
20
+ </a>
21
+ </div>
22
+ </div>
23
+
24
+ > **dParallel: Learnable Parallel Decoding for dLLMs**
25
+ > [Zigeng Chen](https://github.com/czg1225), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Ruonan Yu](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
26
+ > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
27
+
28
+
29
+ ## πŸ’‘ Introduction
30
+ We introduce dParallel, a simple and effective method that unlocks the inherent parallelism of dLLMs for fast sampling. We identify that the key bottleneck to parallel decoding arises from the sequential certainty convergence for masked tokens. Building on this insight, we introduce the core of our approach: certainty-forcing distillation, a novel training strategy that distills the model to follow its original sampling trajectories while enforcing it to achieve high certainty on masked tokens more rapidly and in parallel. Extensive experiments across various benchmarks demonstrate that our method can dramatically reduce the number of decoding steps while maintaining performance. When applied to the LLaDA-8B-Instruct model, dParallel reduces decoding steps from 256 to 30 on GSM8K, achieving an 8.5Γ— speedup without performance degradation. On the MBPP benchmark, it cuts decoding steps from 256 to 24, resulting in a 10.5Γ— speedup while maintaining accuracy.
31
+
32
+ <!-- ![figure](assets/intro.png) -->
33
+ <div align="center">
34
+ <img src="assets/method.png" width="100%" ></img>
35
+ <br>
36
+ <em>
37
+ Overview of proposed certainty-forcing distillation.
38
+ </em>
39
+ </div>
40
+ <br>
41
+
42
+
43
+
44
+ ## πŸ’» Model and Datasets
45
+ <table>
46
+ <table>
47
+ <thead>
48
+ </thead>
49
+ <tbody>
50
+ <tr>
51
+ <td>πŸ“„ <strong>Paper</strong></td>
52
+ <td><a href="https://github.com/czg1225/dParallel">ArXiv-Link</a></td>
53
+ </tr>
54
+ <tr>
55
+ <td>πŸ€– <strong>Model</strong></td>
56
+ <td><a href="https://huggingface.co/Zigeng/dParallel-LLaDA-8b-instruct">dParallel-LLaDA-8b-instruct</a></td>
57
+ </tr>
58
+ <tr>
59
+ <td>πŸ“Š <strong>Data</strong></td>
60
+ <td><a href="https://huggingface.co/datasets/Zigeng/dParallel_LLaDA_Distill_Data">
61
+ dParallel-LLaDA-Distill Dataset</a></td>
62
+ </tr>
63
+ </tbody>
64
+ </table>
65
+
66
+ ## πŸ”₯Updates
67
+ * πŸ”₯ **[Oct 2, 2025]**: Our arxiv paper is available.
68
+ * πŸ”₯ **[Oct 1, 2025]**: Code, model and dataset are released.
69
+
70
+ ## πŸ”§ Installation:
71
+
72
+ ```bash
73
+ conda create -n dparallel python==3.10
74
+ conda activate dparallel
75
+ pip3 install -r requirements.txt
76
+ ```
77
+
78
+ ## πŸš€ Quick Start:
79
+ ```python
80
+ from transformers import AutoTokenizer
81
+ from model.modeling_llada import LLaDAModelLM
82
+ from generate import generate
83
+ import torch
84
+
85
+ device = 'cuda'
86
+ model = LLaDAModelLM.from_pretrained('Zigeng/dParallel-LLaDA-8b-instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
87
+ tokenizer = AutoTokenizer.from_pretrained('Zigeng/dParallel-LLaDA-8b-instruct', trust_remote_code=True)
88
+
89
+ prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Please reason step by step, and put your final answer within \\boxed{}."
90
+
91
+ m = [{"role": "user", "content": prompt}, ]
92
+ prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
93
+
94
+ input_ids = tokenizer(prompt)['input_ids']
95
+ input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
96
+
97
+ out = generate(model, input_ids, steps=256, gen_length=256, block_length=32, temperature=0., threshold=0.5,remasking='low_confidence')
98
+ print("Response:",tokenizer.batch_decode(out[0][:, input_ids.shape[1]:], skip_special_tokens=True)[0])
99
+ print("NFE:",out[1])
100
+ ```
101
+
102
+
103
+ ## πŸ”₯ Training
104
+ ### 1. Certainty-Forcing Distillation with LoRA:
105
+ We provide training scripts for our proposed Certainty-Forcing Distillation process. The implementation utilizes LoRA during the training process, with the configuration details specified in [config_lora_llada.yaml](https://github.com/czg1225/dParallel/blob/master/configs/config_lora_llada.yaml).
106
+ ```bash
107
+ deepspeed --master_port 29501 --include localhost:0,1,2,3 llada_train.py
108
+ ```
109
+
110
+ ### 2. LoRA Merge:
111
+ After training, merge the LoRA weights to get the dParallel-dLLM.
112
+ ```bash
113
+ python merge_lora.py
114
+ ```
115
+
116
+ ## ⚑ Evaluation:
117
+ We provide evaluation scripts for the GSM8K, Minerva_MATH, HumanEval, and MBPP benchmarks. Although our approach does not rely on caching or sparse attention techniques, it is fully compatible with them and can achieve even greater speedups when combined.
118
+ ```bash
119
+ sh eval.sh
120
+ ```
121
+
122
+
123
+ ## πŸ“– Experimental Results
124
+ ### Results on LLaDA-8B-Instruct:
125
+ ![llada-exp](assets/llada_exp.png)
126
+
127
+ ### Results on Dream-7B-Instruct:
128
+ ![dream-exp](assets/dream_exp.png)
129
+
130
+ ### Better Speed-Accuracy Trade-off:
131
+ ![trade-off](assets/trade-off.png)
132
+
133
+ ## β˜€οΈ Acknowledgement
134
+ Our code builds on [LLaDA](https://github.com/ML-GSAI/LLaDA), [Dream](https://github.com/DreamLM/Dream), [Fast-dLLM](https://github.com/NVlabs/Fast-dLLM/tree/main), and [dKV-Cache](https://github.com/horseee/dkv-cache), and we acknowledge these great works for laying the groundwork that made our approach possible.
135
+
136
+ ## Citation
137
+ If our research assists your work, please give us a star ⭐ or cite us using:
138
+ ```
139
+ ```