Chengyue Wu committed
Commit ef88084 · 1 Parent(s): 393167c
update

Files changed:
- README.md +12 -14
- assets/benchmark_results.png +0 -0
- assets/throughput.png +0 -0
- modeling.py +6 -6
README.md CHANGED

@@ -14,7 +14,7 @@ Autoregressive (AR) large language models (LLMs) have achieved remarkable perfor
 
 Our approach introduces a novel decoding recipe incorporating a complementary attention mask and a position-aware masking strategy, which together enable blockwise bidirectional context modeling while preserving the original AR training objectives and performance. To further enhance inference speed, we design a hierarchical caching mechanism: a block-level cache that stores historical context representations and a token-level intra-block cache that supports efficient parallel decoding within partially generated blocks.
 
-Coupled with our parallel decoding pipeline, Fast-dLLM v2 achieves a near […]
+Coupled with our parallel decoding pipeline, Fast-dLLM v2 achieves a near 2.5x speedup over standard AR decoding without compromising generation quality. Extensive experiments demonstrate that Fast-dLLM v2 achieves state-of-the-art trade-offs between efficiency and performance among existing diffusion-based LLMs, marking a significant step toward the practical deployment of fast and accurate language models.
 
 **This repo contains the Fast-dLLM v2 1.5B model**, which has the following features:
 
@@ -98,24 +98,22 @@ print(response)
 
 Fast-dLLM v2 demonstrates state-of-the-art trade-offs between efficiency and performance among existing diffusion-based LLMs. The model achieves:
 
-* Near […]
-* Comparable generation quality to the […]
-
+* Near 2.5x inference speedup compared to standard AR decoding
+* Comparable generation quality to the original Qwen2.5-1.5B-Instruct model
+
+### Throughput Performance
+
+We accelerate the AR model to a near 2.5x speedup at batch size 1.
+At larger batch sizes, the throughput of previous methods degrades, while Fast-dLLM v2 remains consistently faster than AR decoding.
+
+
 
 ### Benchmark Results
 
-
+Fast-dLLM v2 largely preserves the performance of the underlying AR LLM, achieves state-of-the-art results among 1B-scale LLMs, and approaches the performance of the 8B diffusion LLM (LLaDA).
 
-| Model | … | … | … | MBPP+ | … | MATH | IFEval | … | GPQA |
-|-------|-----------|------------|------|-------|-------|------|--------|---------------|------|
-| qwen2.5-1.5B-ar | 42.1 | 37.2 | 48.1 | 41.3 | 57.0 | 22.4 | 41.2 | 54.6 | 30.58 |
-| Fast-dLLM-v2 | **43.3** | **40.2** | **50.0** | 41.3 | **60.1** | **28.4** | **45.7** | **55.1** | 27.7 |
+
 
-**Key Observations:**
-- Fast-dLLM v2 outperforms the base AR model on 7 out of 9 benchmarks
-- Significant improvements in mathematical reasoning (MATH: 22.4 → 28.4) and instruction following (IFEval: 41.2 → 45.7)
-- Comparable performance on MBPP+ and slight decrease on GPQA
-- Overall performance improvement while achieving 4x inference speedup
 
 ## Citation
 
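As an aside for readers of this diff: the hierarchical caching mechanism the README describes can be illustrated with a short sketch. This is a hypothetical, minimal rendering of the idea, not the repo's API; class and method names are invented, and only the in-place slice update mirrors the real code in the modeling.py hunks further down.

```python
import torch

class HierarchicalKVCache:
    """Sketch of the two-level cache: a block-level cache holding the K/V
    of all finished blocks, plus a token-level intra-block cache for the
    block currently being decoded, overwritten in place as masked tokens
    are re-predicted. K/V tensors are (batch, heads, seq_len, head_dim)."""

    def __init__(self, num_layers):
        self.block_kv = [None] * num_layers   # committed historical context
        self.intra_kv = [None] * num_layers   # the partially decoded block

    def commit_block(self, layer, k, v):
        # Append a finished block's K/V to the long-lived context cache.
        if self.block_kv[layer] is None:
            self.block_kv[layer] = (k, v)
        else:
            pk, pv = self.block_kv[layer]
            self.block_kv[layer] = (torch.cat([pk, k], dim=2),
                                    torch.cat([pv, v], dim=2))

    def start_block(self, layer, k, v):
        # Initialize the intra-block cache when a new block begins.
        self.intra_kv[layer] = (k.clone(), v.clone())

    def refresh(self, layer, k, v, pos):
        # Overwrite a slice of the intra-block cache in place -- the same
        # pattern as the replace_position update in modeling.py below.
        ik, iv = self.intra_kv[layer]
        ik[:, :, pos:pos + k.shape[2]] = k
        iv[:, :, pos:pos + v.shape[2]] = v
        return ik, iv
```

The point of the split is that finished blocks are encoded once and never touched again, while the active block's K/V slots are cheaply overwritten as masked tokens are re-predicted in parallel.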
assets/benchmark_results.png ADDED

assets/throughput.png ADDED
modeling.py CHANGED

@@ -163,13 +163,13 @@ class Fast_dLLM_QwenAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
         else:
-            block_cache_key_states = block_past_key_values[self.layer_idx][0]
-            block_cache_value_states = block_past_key_values[self.layer_idx][1]
+            block_cache_key_states = block_past_key_values[self.layer_idx][0]
+            block_cache_value_states = block_past_key_values[self.layer_idx][1]
 
             block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
             block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
-            key_states = block_cache_key_states
-            value_states = block_cache_value_states
+            key_states = block_cache_key_states
+            value_states = block_cache_value_states
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache

@@ -618,7 +618,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
                 logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
                 logits = logits[:, start:end]
             else:
-                logits = self.forward(input_ids=x_t[:, […]
+                logits = self.forward(input_ids=x_t[:, start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
                 logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
         else:
             logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits

@@ -629,7 +629,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
         # Select tokens with probability greater than threshold from p_1t
         x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
-        x1_p = torch.where(mask_idx, x1_p, -torch.inf)
+        x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
 
         unmask_idx = (x1_p > threshold)
         max_prob_idx = x1_p.argmax(dim=-1)
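The @@ -618 hunk is where the intra-block cache is exercised: only the current sub-block is re-encoded, and its K/V overwrite the cache at `small_block_start_idx` instead of appending to the main KV cache. A rough sketch of that call, with the keyword arguments taken from the `+` line above (the wrapper function itself is hypothetical):

```python
import torch

def decode_sub_block(model, x_t, start, end, past_key_values,
                     block_past_key_values, small_block_start_idx):
    # Re-encode only the current sub-block; its K/V are patched into the
    # intra-block cache at small_block_start_idx instead of appending to
    # the main cache (update_past_key_values=False).
    logits = model.forward(
        input_ids=x_t[:, start:end],
        use_cache=True,
        past_key_values=past_key_values,
        update_past_key_values=False,
        use_block_cache=True,
        block_past_key_values=block_past_key_values,
        replace_position=small_block_start_idx,
    ).logits
    # Shift logits right by one position so that index i holds the
    # prediction for token i (the same torch.cat trick as in the diff).
    return torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
```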
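The @@ -629 hunk narrows the confidence mask to the active slice (`mask_idx[:, start:end]`). Around it, the threshold-based parallel decoding step works roughly as below; this is a paraphrased sketch of the surrounding code, not a verbatim excerpt, with `sample_fn` standing in for `self.sample_with_top_p` and tensor shapes assumed to be (batch, block_len) for tokens and (batch, block_len, vocab) for logits:

```python
import torch

def parallel_unmask_step(logits, x_t, mask_idx, threshold, sample_fn):
    """One parallel decoding step over the current block slice.
    x_t: current tokens; mask_idx: True where a token is still masked;
    sample_fn: returns (sampled tokens, full probability distributions)."""
    # Sample a candidate token at every position in the block.
    x_1, p_1t = sample_fn(logits)
    # Confidence of each sampled token (probability assigned to it).
    x1_p = torch.gather(p_1t, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1)
    # Only still-masked positions may be unmasked; others get -inf.
    x1_p = torch.where(mask_idx, x1_p, -torch.inf)
    # Accept every candidate whose confidence clears the threshold...
    unmask = x1_p > threshold
    # ...and always accept the single most confident candidate, so at
    # least one token is committed per step (assumes some position is
    # still masked; otherwise the step should not be called).
    unmask.scatter_(1, x1_p.argmax(dim=-1, keepdim=True), True)
    # Commit accepted tokens; the rest stay masked for the next step.
    x_t = torch.where(unmask, x_1, x_t)
    return x_t, mask_idx & ~unmask
```

Accepting all above-threshold tokens in one pass is what makes the decoding parallel; the argmax fallback guarantees progress even when no candidate clears the threshold.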