Text Generation
Transformers
Safetensors
llada2_moe
conversational
custom_code
Zigeng commited on
Commit
ff5c1d3
Β·
verified Β·
1 Parent(s): b7671b0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +108 -3
README.md CHANGED
@@ -1,3 +1,108 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories
5
+ base_model:
6
+ - inclusionAI/LLaDA2.0-mini
7
+ ---
8
+
9
+ <div align="center">
10
+ <h1>πŸš€ DMax: Aggressive Parallel Decoding for dLLMs</h1>
11
+ <div align="center">
12
+ <a href="https://github.com/czg1225/DMax/blob/main/LICENSE">
13
+ <img alt="Apache" src="https://img.shields.io/badge/License-Apache-4E94CE.svg">
14
+ </a>
15
+ <a href="https://github.com/czg1225/DMax">
16
+ <img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
17
+ </a>
18
+ <a href="https://github.com/czg1225/DMax">
19
+ <img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub">
20
+ </a>
21
+ </div>
22
+ </div>
23
+
24
+ > **DMax: Aggressive Parallel Decoding for dLLMs**
25
+ > [Zigeng Chen](https://czg1225.github.io/chenzigeng99/), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Ruonan Yu](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
26
+ > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
27
+
28
+
29
+ ## πŸ’ͺ Highlights
30
+
31
+ - **Aggressive Decoding Parallelism**: Achieves 6.0 TPF on math and reasoning tasks and 6.6 TPF on code tasks while preserving accuracy.
32
+ - **Self-Revising dLLM**: Extends a pretrained MDLM into a UDLM with an intrinsic ability to revise its own erroneous predictions during decoding.
33
+ - **Soft Parallel Decoding**: Uses interpolation between mask and token embeddings to propagate confidence priors from previous steps.
34
+
35
+ <div align="center">
36
+ <img src="assets/tradeoff.png" width="90%" />
37
+ <br>
38
+ <em>Superior Parallelism-Accuracy Trade-off, Increased TPF with Maintained Accuracy.</em>
39
+ </div>
40
+
41
+ ## πŸ’‘ Introduction
42
+
43
+ We present DMax, a new paradigm for efficient dLLMs. It mitigates error accumulation in parallel decoding, enabling aggressive decoding parallelism while preserving generation quality. Unlike conventional masked dLLMs that decode through a binary mask-to-token transition, DMax reformulates decoding as a progressive self-refinement from mask embeddings to token embeddings. At the core of our approach is On-Policy Uniform Training, a novel training strategy that efficiently unifies masked and uniform dLLMs, equipping the model to recover clean tokens from both masked inputs and its own erroneous predictions. Building on this foundation, we further intoduce Soft Parallel Decoding. Extensive experiments across a variety of benchmarks demonstrate the effectiveness of DMax.
44
+
45
+ <!-- ![figure](assets/intro.png) -->
46
+ <div align="center">
47
+ <img src="assets/train.png" width="100%" />
48
+ <br>
49
+ <em>Overview of the On-Policy Uniform Training.</em>
50
+ </div>
51
+
52
+ ## πŸ’» Model and Datasets
53
+
54
+ | Model | Description | Source Model | Link |
55
+ | --- | --- | --- | --- |
56
+ | πŸ€– DMax-Math-16B | Highly parallel dLLM for math and reasoning. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Math-16B) |
57
+ | πŸ€– DMax-Coder-16B | Highly parallel dLLM for code generation. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Coder-16B) |
58
+
59
+ | Dataset | Description | Link |
60
+ | --- | --- | --- |
61
+ | πŸ“Š DMax-Math-Training-Data | Trajectories on math problems generated by LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories) |
62
+ | πŸ“Š DMax-Code-Training-Data | Trajectories on code problems generated by LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Code-Trajectories) |
63
+
64
+ ## πŸš€ Quick Start
65
+
66
+ ```python
67
+ import torch
68
+ from transformers import AutoModelForCausalLM
69
+ from transformers import AutoTokenizer
70
+
71
+ model = AutoModelForCausalLM.from_pretrained(
72
+ "Zigeng/DMax-Math-16B", trust_remote_code=True, device_map="cuda:0"
73
+ )
74
+ model = model.to(torch.bfloat16)
75
+ model.eval()
76
+ tokenizer = AutoTokenizer.from_pretrained("Zigeng/DMax-Math-16B", trust_remote_code=True)
77
+
78
+ prompt = "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?" + "\nLet's think step by step\n"
79
+
80
+ input_ids = tokenizer.apply_chat_template(
81
+ [{"role": "user", "content": prompt}],
82
+ add_generation_prompt=True,
83
+ tokenize=True,
84
+ return_tensors="pt",
85
+ )
86
+
87
+ nfe, generated_tokens = model.generate_spd(
88
+ inputs=input_ids,
89
+ gen_length=2048,
90
+ block_length=32,
91
+ threshold=0.0,
92
+ )
93
+
94
+ generated_answer = tokenizer.decode(
95
+ generated_tokens[0],
96
+ skip_special_tokens=True,
97
+ )
98
+
99
+ print(generated_answer)
100
+ print("nfe:",nfe,"token length",len(generated_tokens[0]))
101
+ ```
102
+
103
+ ## πŸ“– Experimental Results
104
+
105
+ ![trade-off](assets/exp.png)
106
+
107
+ ---
108
+