---
license: apache-2.0
datasets:
- Zigeng/DMax-LLaDA-2.0-Mini-Code-Trajectories
base_model:
- inclusionAI/LLaDA2.0-mini
---

<div align="center">
<h1>🚀 DMax: Aggressive Parallel Decoding for dLLMs</h1>
<div align="center">
<a href="https://github.com/czg1225/DMax/blob/main/LICENSE">
<img alt="Apache" src="https://img.shields.io/badge/License-Apache-4E94CE.svg">
</a>
<a href="https://github.com/czg1225/DMax">
<img src="https://img.shields.io/badge/Paper-Arxiv-darkred.svg" alt="Paper">
</a>
<a href="https://github.com/czg1225/DMax">
<img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub">
</a>
</div>
</div>

> **DMax: Aggressive Parallel Decoding for dLLMs**
> [Zigeng Chen](https://czg1225.github.io/chenzigeng99/), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Ruonan Yu](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore

## 💪 Highlights

- **Aggressive Decoding Parallelism**: Achieves 6.0 TPF (tokens per forward pass) on math and reasoning tasks and 6.6 TPF on code tasks while preserving accuracy.
- **Self-Revising dLLM**: Extends a pretrained MDLM into a UDLM with an intrinsic ability to revise its own erroneous predictions during decoding.
- **Soft Parallel Decoding**: Interpolates between mask and token embeddings to propagate confidence priors from previous decoding steps.
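The interpolation behind soft parallel decoding can be sketched in a few lines. This is an illustrative reconstruction, not code from this repository: `soft_embed` is a hypothetical helper, and the actual model applies the blend inside its forward pass.

```python
import torch

def soft_embed(mask_emb: torch.Tensor,
               token_emb: torch.Tensor,
               confidence: torch.Tensor) -> torch.Tensor:
    """Blend [MASK] and predicted-token embeddings by per-position confidence.

    mask_emb, token_emb: (batch, seq, dim); confidence: (batch, seq) in [0, 1].
    High-confidence positions feed the next step an input close to the
    predicted token; low-confidence positions stay close to the mask,
    which propagates the previous step's confidence prior forward.
    """
    c = confidence.unsqueeze(-1)  # (batch, seq, 1), broadcasts over dim
    return c * token_emb + (1.0 - c) * mask_emb
```

A position with confidence 1.0 receives the token embedding unchanged, confidence 0.0 leaves it fully masked, and intermediate values yield the soft inputs used between decoding steps.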

<div align="center">
<img src="assets/tradeoff.png" width="90%" />
<br>
<em>Superior parallelism-accuracy trade-off: increased TPF with maintained accuracy.</em>
</div>

## 💻 Model and Datasets

| Model | Description | Source Model | Link |
| --- | --- | --- | --- |
| 🤖 DMax-Math-16B | Highly parallel dLLM for math and reasoning. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Math-16B) |
| 🤖 DMax-Coder-16B | Highly parallel dLLM for code generation. | LLaDA-2.0-mini | [Hugging Face](https://huggingface.co/Zigeng/DMax-Coder-16B) |

| Dataset | Description | Link |
| --- | --- | --- |
| 📊 DMax-Math-Training-Data | Trajectories on math problems generated by LLaDA-2.0-mini. | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories) |
| 📊 DMax-Code-Training-Data | Trajectories on code problems generated by LLaDA-2.0-mini. | [Hugging Face](https://huggingface.co/datasets/Zigeng/DMax-LLaDA-2.0-Mini-Code-Trajectories) |

## 🚀 Quick Start

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Zigeng/DMax-Coder-16B", trust_remote_code=True, device_map="cuda:0"
)
model = model.to(torch.bfloat16)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Zigeng/DMax-Coder-16B", trust_remote_code=True)

prompt = "Write a python function to find the first repeated character in a given string." + "\n\nPlease enclose your code within delimiters as follows:\n```python\n# YOUR CODE HERE\n```\n\n"

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
)

nfe, generated_tokens = model.generate_spd(
    inputs=input_ids,
    gen_length=2048,
    block_length=32,
    threshold=0.65,
)

generated_answer = tokenizer.decode(
    generated_tokens[0],
    skip_special_tokens=True,
)

print(generated_answer)
print("nfe:", nfe, "token length:", len(generated_tokens[0]))
```
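`generate_spd` takes a confidence `threshold` (0.65 above). As a rough sketch of how threshold-based parallel decoding typically chooses which masked positions to commit per step — an illustration under assumed semantics, not the actual `generate_spd` internals:

```python
import torch

def select_parallel_commits(logits: torch.Tensor, threshold: float = 0.65):
    """Pick which masked positions to commit in one decoding step.

    logits: (num_masked, vocab_size) predictions for the masked positions.
    Commits every position whose top-1 probability clears the threshold;
    if none do, commits the single most confident position so each step
    is guaranteed to make progress.
    """
    probs = torch.softmax(logits, dim=-1)
    conf, tokens = probs.max(dim=-1)  # per-position confidence and argmax token
    commit = conf >= threshold
    if not commit.any():
        commit[conf.argmax()] = True
    return tokens, commit
```

Raising the threshold commits fewer tokens per forward pass (lower TPF, more conservative); lowering it decodes more aggressively in parallel.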

## 📖 Experimental Results

![trade-off](assets/exp.png)