Chengyue Wu committed
Commit ef88084 · 1 Parent(s): 393167c
README.md CHANGED
@@ -14,7 +14,7 @@ Autoregressive (AR) large language models (LLMs) have achieved remarkable perfor
 
 Our approach introduces a novel decoding recipe incorporating a complementary attention mask and a position-aware masking strategy, which together enable blockwise bidirectional context modeling while preserving the original AR training objectives and performance. To further enhance inference speed, we design a hierarchical caching mechanism: a block-level cache that stores historical context representations and a token-level intra-block cache that supports efficient parallel decoding within partially generated blocks.
 
-Coupled with our parallel decoding pipeline, Fast-dLLM v2 achieves a near 4x speedup over standard AR decoding, without compromising generation quality. Extensive experiments demonstrate that Fast-dLLM v2 achieves state-of-the-art trade-offs between efficiency and performance among existing diffusion-based LLMs, marking a significant step toward practical deployment of fast and accurate language models.
+Coupled with our parallel decoding pipeline, Fast-dLLM v2 achieves a near 2.5x speedup over standard AR decoding, without compromising generation quality. Extensive experiments demonstrate that Fast-dLLM v2 achieves state-of-the-art trade-offs between efficiency and performance among existing diffusion-based LLMs, marking a significant step toward practical deployment of fast and accurate language models.
 
 **This repo contains the Fast-dLLM v2 1.5B model**, which has the following features:
 
@@ -98,24 +98,22 @@ print(response)
 
 Fast-dLLM v2 demonstrates state-of-the-art trade-offs between efficiency and performance among existing diffusion-based LLMs. The model achieves:
 
-* Near 4x inference speedup compared to standard AR decoding
-* Comparable generation quality to the base Qwen2.5-1.5B-Instruct model
-* Efficient memory usage through hierarchical caching mechanisms
+* Near 2.5x inference speedup compared to standard AR decoding
+* Comparable generation quality to the original Qwen2.5-1.5B-Instruct model
+
+### Throughput Performance
+
+We accelerate the AR model with a near 2.5x speedup at batch size 1.
+At larger batch sizes, the throughput of previous methods degrades, while Fast-dLLM-v2 remains consistently faster than AR decoding.
+
+![Throughput Comparison](assets/throughput.png)
 
 ### Benchmark Results
 
-The following table compares the performance of Fast-dLLM-v2 against the base autoregressive model (qwen2.5-1.5B-ar) across various benchmarks:
+Fast-dLLM-v2 largely preserves the performance of the AR LLM, achieves state-of-the-art results among 1B-scale LLMs, and is competitive with the 8B diffusion LLM (LLaDA).
 
-| Model | HumanEval | HumanEval+ | MBPP | MBPP+ | GSM8K | MATH | IFEval | MMLU (0-shot) | GPQA |
-|-------|-----------|------------|------|-------|-------|------|--------|---------------|------|
-| qwen2.5-1.5B-ar | 42.1 | 37.2 | 48.1 | 41.3 | 57.0 | 22.4 | 41.2 | 54.6 | 30.58 |
-| Fast-dLLM-v2 | **43.3** | **40.2** | **50.0** | 41.3 | **60.1** | **28.4** | **45.7** | **55.1** | 27.7 |
+![Benchmark Results](assets/benchmark_results.png)
 
-**Key Observations:**
-- Fast-dLLM v2 outperforms the base AR model on 7 out of 9 benchmarks
-- Significant improvements in mathematical reasoning (MATH: 22.4 → 28.4) and instruction following (IFEval: 41.2 → 45.7)
-- Comparable performance on MBPP+ and slight decrease on GPQA
-- Overall performance improvement while achieving 4x inference speedup
 
 ## Citation
 
assets/benchmark_results.png ADDED
assets/throughput.png ADDED
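
The decoding recipe the updated README describes (predict every masked position in a block, then commit only the confident ones) can be sketched in a few lines. This is a simplified illustration, not the repo's implementation: it uses greedy candidate selection where the repo samples with top-p, and `parallel_unmask_step`, `mask_idx`, and the `threshold` default are assumed names and values.

```python
import torch

@torch.no_grad()
def parallel_unmask_step(logits, x_t, mask_idx, threshold=0.9):
    """One confidence-thresholded parallel decoding step inside a block.

    logits:   (B, L, V) model predictions for the current block window
    x_t:      (B, L) partially generated block containing mask tokens
    mask_idx: (B, L) bool, True where x_t is still masked
    """
    probs = torch.softmax(logits.float(), dim=-1)
    p_max, x_pred = probs.max(dim=-1)              # greedy candidate per slot
    neg_inf = torch.full_like(p_max, -torch.inf)
    p_max = torch.where(mask_idx, p_max, neg_inf)  # only masked slots compete

    accept = p_max > threshold                     # commit confident slots
    # Always commit the single most confident slot so the loop makes
    # progress even when nothing clears the threshold.
    accept[torch.arange(x_t.size(0)), p_max.argmax(dim=-1)] = True
    accept &= mask_idx                             # never touch decoded tokens

    x_t = torch.where(accept, x_pred, x_t)
    return x_t, mask_idx & ~accept
```

Iterating this step until no masked positions remain decodes a block in a handful of forward passes rather than one pass per token, which is where the batch-size-1 speedup comes from.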
modeling.py CHANGED
@@ -163,13 +163,13 @@ class Fast_dLLM_QwenAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
         else:
-            block_cache_key_states = block_past_key_values[self.layer_idx][0].clone()
-            block_cache_value_states = block_past_key_values[self.layer_idx][1].clone()
+            block_cache_key_states = block_past_key_values[self.layer_idx][0]
+            block_cache_value_states = block_past_key_values[self.layer_idx][1]
 
             block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
             block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
-            key_states = block_cache_key_states.contiguous()
-            value_states = block_cache_value_states.contiguous()
+            key_states = block_cache_key_states
+            value_states = block_cache_value_states
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -618,7 +618,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
                     logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
                     logits = logits[:, start:end]
                 else:
-                    logits = self.forward(input_ids=x_t[:, -block_size+small_block_start_idx:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
+                    logits = self.forward(input_ids=x_t[:, start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
                     logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
             else:
                 logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
@@ -629,7 +629,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
             x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
             # Select tokens with probability greater than threshold from p_1t
             x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
-            x1_p = torch.where(mask_idx, x1_p, -torch.inf)
+            x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
 
             unmask_idx = (x1_p > threshold)
             max_prob_idx = x1_p.argmax(dim=-1)
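
The first hunk drops the `.clone()` and `.contiguous()` calls and instead writes the freshly computed key/value states directly into the block cache at `replace_position`, saving two full-size K/V copies per decoding step. A minimal sketch of the pattern, with illustrative shapes and names rather than the repo's exact cache API:

```python
import torch

# Illustrative shapes: the block cache holds per-layer (key, value) tensors
# covering one block of positions.
batch, heads, block_len, head_dim = 1, 12, 32, 128
cache_k = torch.zeros(batch, heads, block_len, head_dim)

def update_block_cache_inplace(cache_k, new_k, replace_position):
    # Before this commit: work on cache_k.clone() and return .contiguous(),
    # i.e. two extra full-size copies per step. After: mutate the cache
    # directly and hand the same storage to attention.
    cache_k[:, :, replace_position:replace_position + new_k.shape[2]] = new_k
    return cache_k

new_k = torch.randn(batch, heads, 4, head_dim)   # 4 re-decoded positions
k_for_attn = update_block_cache_inplace(cache_k, new_k, replace_position=8)
assert k_for_attn.data_ptr() == cache_k.data_ptr()  # no copy was made
```

The trade-off is aliasing: this is only safe because each step overwrites exactly the positions it re-decodes before attention reads the cache, so the defensive copy buys nothing here.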
 
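
The last hunk is a correctness fix that accompanies the `x_t[:, start:end]` change above: once only the active window is fed through the model, the confidence tensor `x1_p` has width `end - start`, so the full-length `mask_idx` must be sliced to the same window before `torch.where`. With illustrative shapes:

```python
import torch

B, L, start, end = 2, 32, 8, 16
mask_idx = torch.zeros(B, L, dtype=torch.bool)
mask_idx[:, start:end] = True            # only the active sub-block is masked
x1_p = torch.rand(B, end - start)        # confidences for the window only

# Slicing the mask makes both operands (B, end - start); the unsliced (B, L)
# mask would misalign positions or raise a shape error.
out = torch.where(mask_idx[:, start:end], x1_p, torch.full_like(x1_p, -torch.inf))
assert out.shape == (B, end - start)
```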