Text Generation
Transformers
Safetensors
PyTorch
nemotron_labs_diffusion
feature-extraction
nvidia
conversational
custom_code
Instructions to use nvidia/Nemotron-Labs-Diffusion-3B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/Nemotron-Labs-Diffusion-3B with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="nvidia/Nemotron-Labs-Diffusion-3B", trust_remote_code=True) messages = [ {"role": "user", "content": "Who are you?"}, ] pipe(messages)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nvidia/Nemotron-Labs-Diffusion-3B", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use nvidia/Nemotron-Labs-Diffusion-3B with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "nvidia/Nemotron-Labs-Diffusion-3B" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-3B", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-3B
- SGLang
How to use nvidia/Nemotron-Labs-Diffusion-3B with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-3B" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-3B", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-3B" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-3B", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }' - Docker Model Runner
How to use nvidia/Nemotron-Labs-Diffusion-3B with Docker Model Runner:
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-3B
Clean up rope params; ensure transformers 4.55/5.0 compatibility
#2
by abhgarg - opened
- .gitattributes +0 -5
- README.md +97 -97
- assets/demo.gif +0 -3
- assets/demo.mp4 +0 -3
- assets/result_acc.png +0 -3
- assets/result_efficiency.png +0 -3
- assets/teaser.png +0 -3
- chat_utils.py +313 -0
- config.json +21 -4
- configuration_nemotron_labs_diffusion.py → configuration_ministral_dlm.py +75 -8
- generation_config.json +1 -1
- linear_spec_lora/adapter_config.json +0 -34
- linear_spec_lora/adapter_model.safetensors +0 -3
- model_cards/bias.md +0 -4
- model_cards/explainability.md +0 -13
- model_cards/privacy.md +0 -11
- model_cards/safety.md +0 -6
- modeling_ministral.py +99 -7
- modeling_ministral_dlm.py +1860 -0
- modeling_nemotron_labs_diffusion.py +0 -870
.gitattributes
CHANGED
|
@@ -34,8 +34,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
-
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
-
assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
-
assets/result_acc.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
-
assets/result_efficiency.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
-
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,160 +1,160 @@
|
|
| 1 |
---
|
| 2 |
library_name: transformers
|
| 3 |
-
|
| 4 |
-
license_name: nvidia-nemotron-open-model-license
|
| 5 |
-
license_link: >-
|
| 6 |
-
https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/
|
| 7 |
-
pipeline_tag: text-generation
|
| 8 |
-
tags:
|
| 9 |
-
- nvidia
|
| 10 |
-
- pytorch
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# Nemotron-
|
| 14 |
|
|
|
|
| 15 |
|
| 16 |
-
<div align="center" style="line-height: 1;">
|
| 17 |
-
<a href="https://d1qx31qr3h6wln.cloudfront.net/publications/Nemotron_Diffusion_Tech_Report_v1.pdf?VersionId=db8_EMO8B.vmU26.jr7Le9pN3MqcUDNL" target="_blank" style="margin: 2px;">
|
| 18 |
-
<img alt="Chat" src="https://img.shields.io/badge/📝Paper-Read Now!-536af5?color=76B900&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 19 |
-
</a>
|
| 20 |
-
<a href="https://huggingface.co/collections/nvidia/nemotron-labs-diffusion" target="_blank" style="margin: 2px;">
|
| 21 |
-
<img alt="Nemotron-Labs-Diffusion Model Family" src="https://img.shields.io/badge/%F0%9F%A4%97-Nemotron--Labs--Diffusion_Model_Family-76B900" style="display: inline-block; vertical-align: middle;"/>
|
| 22 |
-
</a>
|
| 23 |
-
<a href="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/" style="margin: 2px;">
|
| 24 |
-
<img alt="License" src="https://img.shields.io/badge/License-NVIDIA Open Model License-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
|
| 25 |
-
</a>
|
| 26 |
-
</div>
|
| 27 |
|
|
|
|
| 28 |
|
| 29 |
-
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
</div>
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
- SOTA 3B, 8B, 14B dense LM family (base, instruct, and vision-language variants) supporting AR, diffusion, and self-speculation with the focus on decode efficiency.
|
| 44 |
-
- Generation moved from a memory-bound regime toward a compute-bound regime. Model weights are loaded once and reused to compute multiple tokens during generation.
|
| 45 |
-
- Self-speculation uses diffusion for drafting and AR for verification, providing a stronger alternative to MTP approaches:
|
| 46 |
-
* 3x higher acceptance length and 2.2x speed-up vs. Qwen3-8B-Eagle3 in SGLang.
|
| 47 |
-
* 5.9× tokens per forward over Qwen3-8B (no MTP) with the same accuracy.
|
| 48 |
-
- Real-device speed-up across platforms:
|
| 49 |
-
* DGX Spark (8B, concurrency 1): 2.7x faster with 112 tok/sec vs. 41.8 tok/sec AR using w4a16.
|
| 50 |
-
* GB200 (8B, concurrency 1): 3.3x faster with 850 tok/sec vs. 253 tok/sec AR and 360 tok/sec Eagle3. Custom CUDA kernels boost to 1015 tok/sec (4x).
|
| 51 |
-
- Diffusion speedup-of-light analysis shows that throughput can be further doubled (vs. current best) for a single user with better sampling - future research.
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
-
<div align="center">
|
| 55 |
-
<img src="./assets/result_acc.png" alt="Efficiency Results" width="800">
|
| 56 |
-
</div>
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
|
|
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
```
|
| 73 |
|
| 74 |
-
## Chat with Our Model
|
| 75 |
|
|
|
|
| 76 |
|
| 77 |
```
|
| 78 |
-
from transformers import AutoModel, AutoTokenizer
|
| 79 |
import torch
|
| 80 |
|
| 81 |
-
repo_name = "nvidia/Nemotron-
|
| 82 |
|
| 83 |
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
history = []
|
| 88 |
|
| 89 |
user_input = input("User: ").strip()
|
| 90 |
history.append({"role": "user", "content": user_input})
|
| 91 |
|
| 92 |
-
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
|
| 93 |
-
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
-
|
| 99 |
-
out_ids, nfe = model.generate(prompt_ids, max_new_tokens=512, block_length=32, threshold=0.9, eos_token_id=tokenizer.eos_token_id)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
out_ids, nfe = model.linear_spec_generate(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
|
| 103 |
|
| 104 |
-
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 105 |
print(f"Model: {tokenized_out}")
|
| 106 |
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 107 |
```
|
| 108 |
|
|
|
|
| 109 |
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
|
| 112 |
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
|
|
|
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
from transformers import AutoModel, AutoTokenizer
|
| 119 |
-
from peft import PeftModel
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
model = model.cuda().to(torch.bfloat16)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
history = [{"role": "user", "content": "Solve: What is 15% of 240?"}]
|
| 132 |
-
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
|
| 133 |
-
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
|
| 134 |
-
|
| 135 |
-
out_ids, nfe = base.linear_spec_generate(
|
| 136 |
-
prompt_ids, max_new_tokens=512, block_length=32,
|
| 137 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 138 |
-
)
|
| 139 |
-
print(tokenizer.decode(out_ids[0, prompt_ids.shape[1]:], skip_special_tokens=True))
|
| 140 |
-
print(f"[NFE={nfe}]")
|
| 141 |
```
|
|
|
|
|
|
|
| 142 |
|
|
|
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
|
|
|
| 146 |
|
| 147 |
-
|
| 148 |
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
author = {Yonggan Fu and Lexington Whalen and Abhinav Garg and Chengyue Wu and Maksim Khadkevich and Nicolai Oswald and Enze Xie and Daniel Egert and Sharath Turuvekere Sreenivas and Shizhe Diao and Chenhan Yu and Ye Yu and Weijia Chen and Sajad Norouzi and Shiyi Lan and Ligeng Zhu and Jin Wang and Jindong Jiang and Morteza Mardani and Mehran Maghoumi and Song Han and Ante Jukic and Nima Tajbakhsh and Jan Kautz and Pavlo Molchanov},
|
| 156 |
-
institution = {NVIDIA},
|
| 157 |
-
year = {2026},
|
| 158 |
-
note = {Technical report}
|
| 159 |
-
}
|
| 160 |
```
|
|
|
|
|
|
| 1 |
---
|
| 2 |
library_name: transformers
|
| 3 |
+
tags: []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
---
|
| 5 |
|
| 6 |
+
# Nemotron-Diffusion-Exp-Ministral-3B-Instruct
|
| 7 |
|
| 8 |
+
Developed by [DLER team](https://nv-dler.github.io/) @ NVR and will be updated actively. Contact Yonggan Fu and Pavlo Molchanov for any question.
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
# Environment
|
| 12 |
|
| 13 |
+
Docker path: `/lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm_ministral.sqsh` on CW-DFW. Apply for interactive nodes with the following command:
|
| 14 |
|
| 15 |
+
```
|
| 16 |
+
srun -A {account} --partition interactive --time 4:00:00 --gpus 8 --container-image /lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm_ministral.sqsh --container-mounts=$HOME:/home,/lustre:/lustre --pty bash
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Chat with Our Model in dLM Mode
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
from transformers import AutoModel, AutoTokenizer
|
| 24 |
+
import torch
|
| 25 |
|
| 26 |
+
repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
|
| 27 |
+
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 29 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
|
| 30 |
+
model = model.cuda().to(torch.bfloat16)
|
| 31 |
|
| 32 |
+
history = []
|
| 33 |
|
| 34 |
+
user_input = input("User: ").strip()
|
| 35 |
+
history.append({"role": "user", "content": user_input})
|
|
|
|
| 36 |
|
| 37 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
|
| 38 |
+
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 39 |
+
out_ids, nfe = model.generate(prompt_ids, max_new_tokens=512, steps=512, block_length=32, shift_logits=False, causal_context=True, threshold=0.9, eos_token_id=tokenizer.eos_token_id)
|
| 40 |
|
| 41 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 42 |
+
print(f"Model: {tokenized_out}")
|
| 43 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 44 |
+
```
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
## Chat with Our Model in AR Mode
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
```
|
| 51 |
+
from transformers import AutoModel, AutoTokenizer
|
| 52 |
+
import torch
|
| 53 |
|
| 54 |
+
repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
|
| 55 |
|
| 56 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 57 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
|
| 58 |
+
model = model.cuda().to(torch.bfloat16)
|
| 59 |
|
| 60 |
+
history = []
|
| 61 |
|
| 62 |
+
user_input = input("User: ").strip()
|
| 63 |
+
history.append({"role": "user", "content": user_input})
|
| 64 |
|
| 65 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
| 66 |
+
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 67 |
+
out_ids, nfe = model.ar_generate(inputs.input_ids, max_new_tokens=512)
|
| 68 |
|
| 69 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 70 |
+
print(f"Model: {tokenized_out}")
|
| 71 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 72 |
```
|
| 73 |
|
|
|
|
| 74 |
|
| 75 |
+
## Chat with Our Model in Quadratic Self-Speculation Mode
|
| 76 |
|
| 77 |
```
|
| 78 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
| 79 |
import torch
|
| 80 |
|
| 81 |
+
repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
|
| 82 |
|
| 83 |
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 84 |
+
|
| 85 |
+
config = AutoConfig.from_pretrained(repo_name, trust_remote_code=True)
|
| 86 |
+
config.enable_self_spec = True
|
| 87 |
+
|
| 88 |
+
model = AutoModel.from_pretrained(repo_name, config=config, trust_remote_code=True).cuda().to(torch.bfloat16)
|
| 89 |
|
| 90 |
history = []
|
| 91 |
|
| 92 |
user_input = input("User: ").strip()
|
| 93 |
history.append({"role": "user", "content": user_input})
|
| 94 |
|
| 95 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
|
|
|
| 96 |
|
| 97 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 98 |
+
inputs = inputs.to("cuda")
|
| 99 |
|
| 100 |
+
out_ids, nfe = model.self_spec_generate(inputs.input_ids, max_new_tokens=512, steps=512, block_length=32, ar_mix_weight=0.5, eos_token_id=tokenizer.eos_token_id)
|
|
|
|
| 101 |
|
| 102 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
|
|
|
|
| 103 |
|
|
|
|
| 104 |
print(f"Model: {tokenized_out}")
|
| 105 |
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 106 |
```
|
| 107 |
|
| 108 |
+
## Chat with Our Model in Linear Self-Speculation Mode
|
| 109 |
|
| 110 |
+
```
|
| 111 |
+
from transformers import AutoModel, AutoTokenizer
|
| 112 |
+
import torch
|
| 113 |
|
| 114 |
+
repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
|
| 115 |
|
| 116 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 117 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
|
| 118 |
+
model = model.cuda().to(torch.bfloat16)
|
| 119 |
|
| 120 |
+
history = []
|
| 121 |
|
| 122 |
+
user_input = input("User: ").strip()
|
| 123 |
+
history.append({"role": "user", "content": user_input})
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
| 126 |
+
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 127 |
+
out_ids, nfe = model.linear_spec_generate(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
|
|
|
|
| 128 |
|
| 129 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 130 |
+
print(f"Model: {tokenized_out}")
|
| 131 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Chat with Our Model in Linear Decoding Mode with Multi-Path Verification
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
```
|
| 138 |
+
from transformers import AutoModel, AutoTokenizer
|
| 139 |
+
import torch
|
| 140 |
|
| 141 |
+
repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
|
| 142 |
|
| 143 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 144 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
|
| 145 |
+
model = model.cuda().to(torch.bfloat16)
|
| 146 |
|
| 147 |
+
history = []
|
| 148 |
|
| 149 |
+
user_input = input("User: ").strip()
|
| 150 |
+
history.append({"role": "user", "content": user_input})
|
| 151 |
|
| 152 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
| 153 |
+
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 154 |
+
out_ids, nfe = model.linear_spec_generate_mp(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
|
| 155 |
|
| 156 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 157 |
+
print(f"Model: {tokenized_out}")
|
| 158 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
```
|
| 160 |
+
|
assets/demo.gif
DELETED
Git LFS Details
|
assets/demo.mp4
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:666d8785ac4af75931d9c677757c4ef9945bf114d07f1c4e2ebb7b893ac39006
|
| 3 |
-
size 9454873
|
|
|
|
|
|
|
|
|
|
|
|
assets/result_acc.png
DELETED
Git LFS Details
|
assets/result_efficiency.png
DELETED
Git LFS Details
|
assets/teaser.png
DELETED
Git LFS Details
|
chat_utils.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def add_gumbel_noise(logits, temperature):
|
| 7 |
+
'''
|
| 8 |
+
The Gumbel max is a method for sampling categorical distributions.
|
| 9 |
+
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
| 10 |
+
Thus, we use float64.
|
| 11 |
+
'''
|
| 12 |
+
if temperature == 0:
|
| 13 |
+
return logits
|
| 14 |
+
logits = logits.to(torch.float64)
|
| 15 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 16 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 17 |
+
return logits.exp() / gumbel_noise
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
|
| 21 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 22 |
+
x0 = torch.argmax(logits_with_noise, dim=-1)
|
| 23 |
+
|
| 24 |
+
if remasking == 'low_confidence':
|
| 25 |
+
# p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
p = F.softmax(logits, dim=-1)
|
| 27 |
+
x0_p = torch.squeeze(
|
| 28 |
+
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 29 |
+
elif remasking == 'top_p_margin':
|
| 30 |
+
# Compute probabilities
|
| 31 |
+
p = F.softmax(logits, dim=-1) # (B, L, V)
|
| 32 |
+
# Top-2 per position
|
| 33 |
+
top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
|
| 34 |
+
margin = top2[..., 0] - top2[..., 1] # (B, L)
|
| 35 |
+
|
| 36 |
+
# Normalize margin to [0,1] over MASKED positions per row
|
| 37 |
+
plus_inf = torch.full_like(margin, float('inf'))
|
| 38 |
+
minus_inf = torch.full_like(margin, float('-inf'))
|
| 39 |
+
masked_for_min = torch.where(mask_index, margin, plus_inf)
|
| 40 |
+
masked_for_max = torch.where(mask_index, margin, minus_inf)
|
| 41 |
+
row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
|
| 42 |
+
row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
|
| 43 |
+
denom = (row_max - row_min)
|
| 44 |
+
|
| 45 |
+
# If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
|
| 46 |
+
normalized = torch.zeros_like(margin)
|
| 47 |
+
nonzero = denom > 0
|
| 48 |
+
normalized = torch.where(
|
| 49 |
+
mask_index & nonzero,
|
| 50 |
+
(margin - row_min) / (denom + 1e-12),
|
| 51 |
+
normalized
|
| 52 |
+
)
|
| 53 |
+
normalized = torch.where(
|
| 54 |
+
mask_index & (~nonzero),
|
| 55 |
+
torch.ones_like(normalized),
|
| 56 |
+
normalized
|
| 57 |
+
)
|
| 58 |
+
x0_p = normalized # ∈ [0,1] on masked positions
|
| 59 |
+
elif remasking == 'random':
|
| 60 |
+
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
| 61 |
+
else:
|
| 62 |
+
raise NotImplementedError(remasking)
|
| 63 |
+
|
| 64 |
+
# Calculate negative entropy if requested
|
| 65 |
+
if neg_entropy:
|
| 66 |
+
# p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 67 |
+
p = F.softmax(logits, dim=-1)
|
| 68 |
+
epsilon = 1e-10
|
| 69 |
+
log_probs = torch.log(p + epsilon)
|
| 70 |
+
confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
|
| 71 |
+
else:
|
| 72 |
+
confidence_scores = x0_p
|
| 73 |
+
|
| 74 |
+
x0 = torch.where(mask_index, x0, x)
|
| 75 |
+
confidence = torch.where(mask_index, confidence_scores, -np.inf)
|
| 76 |
+
|
| 77 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 78 |
+
if threshold is not None:
|
| 79 |
+
num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
|
| 80 |
+
# print(f'confidence: {confidence}')
|
| 81 |
+
for j in range(confidence.shape[0]):
|
| 82 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
|
| 83 |
+
transfer_index[j, select_index] = True
|
| 84 |
+
if threshold is not None:
|
| 85 |
+
for k in range(1, num_transfer_tokens[j]):
|
| 86 |
+
if confidence[j, select_index[k]] < threshold:
|
| 87 |
+
transfer_index[j, select_index[k]] = False
|
| 88 |
+
return x0, transfer_index
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_num_transfer_tokens(mask_index, steps: int):
|
| 92 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 93 |
+
base = mask_num // steps
|
| 94 |
+
remainder = mask_num % steps
|
| 95 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 96 |
+
for i in range(mask_num.size(0)):
|
| 97 |
+
num_transfer_tokens[i, : int(remainder[i])] += 1
|
| 98 |
+
return num_transfer_tokens
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def generate_with_prefix_cache_block_diff(
|
| 103 |
+
model,
|
| 104 |
+
prompt,
|
| 105 |
+
steps=128,
|
| 106 |
+
gen_length=128,
|
| 107 |
+
block_length=128,
|
| 108 |
+
temperature=0.,
|
| 109 |
+
remasking='low_confidence',
|
| 110 |
+
mask_id=126336,
|
| 111 |
+
threshold=None,
|
| 112 |
+
factor=None,
|
| 113 |
+
shift_logits=False,
|
| 114 |
+
neg_entropy=False,
|
| 115 |
+
causal_context=False,
|
| 116 |
+
eos_token_id=None,
|
| 117 |
+
max_thinking_tokens=None,
|
| 118 |
+
end_think_token_id=None,
|
| 119 |
+
):
|
| 120 |
+
dream_style=shift_logits
|
| 121 |
+
x_accum = prompt.clone()
|
| 122 |
+
B = prompt.shape[0]
|
| 123 |
+
|
| 124 |
+
assert gen_length % block_length == 0
|
| 125 |
+
num_blocks = gen_length // block_length
|
| 126 |
+
|
| 127 |
+
assert steps % num_blocks == 0
|
| 128 |
+
steps_per_block = steps // num_blocks
|
| 129 |
+
|
| 130 |
+
nfe = 0
|
| 131 |
+
|
| 132 |
+
if causal_context:
|
| 133 |
+
model_module = model.module if hasattr(model, "module") else model
|
| 134 |
+
for layer in model_module.encoder.layers:
|
| 135 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 136 |
+
layer.self_attn.diffusion_lm=False
|
| 137 |
+
|
| 138 |
+
# Compute KV cache for the prompt initially
|
| 139 |
+
output = model(prompt, use_cache=True, use_causal_mask=causal_context)
|
| 140 |
+
past_key_values = output.past_key_values
|
| 141 |
+
|
| 142 |
+
if causal_context:
|
| 143 |
+
for layer in model_module.encoder.layers:
|
| 144 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 145 |
+
layer.self_attn.diffusion_lm=True
|
| 146 |
+
|
| 147 |
+
# Causal prefill: next token from last position (same as linear_spec_generate).
|
| 148 |
+
next_token = None
|
| 149 |
+
if causal_context:
|
| 150 |
+
last_logit = output.logits[:, -1, :]
|
| 151 |
+
if temperature > 0:
|
| 152 |
+
probs = torch.softmax(last_logit / temperature, dim=-1)
|
| 153 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 154 |
+
else:
|
| 155 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 156 |
+
|
| 157 |
+
# For dream_style: store the "next token logit" of the context
|
| 158 |
+
next_logits_context = None
|
| 159 |
+
if dream_style:
|
| 160 |
+
next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
|
| 161 |
+
|
| 162 |
+
for num_block in range(num_blocks):
|
| 163 |
+
# Create a new block with mask tokens; under causal context, seed position 0
|
| 164 |
+
# with the next-token prediction from the previous causal forward (prefill or
|
| 165 |
+
# post-block encode), matching linear_spec_generate.
|
| 166 |
+
mask_block = torch.ones(
|
| 167 |
+
(prompt.shape[0], block_length),
|
| 168 |
+
dtype=prompt.dtype,
|
| 169 |
+
device=prompt.device
|
| 170 |
+
) * mask_id
|
| 171 |
+
if causal_context:
|
| 172 |
+
mask_block[:, 0] = next_token[:, 0]
|
| 173 |
+
|
| 174 |
+
# Append the block of masks
|
| 175 |
+
x_accum = torch.cat([x_accum, mask_block], dim=1)
|
| 176 |
+
current_block_start = prompt.size(1) + num_block * block_length
|
| 177 |
+
block_slice = slice(current_block_start, current_block_start + block_length)
|
| 178 |
+
|
| 179 |
+
# ---- thinking budget enforcement ----
|
| 180 |
+
# If we've generated >= max_thinking_tokens without a </think>, inject one.
|
| 181 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 182 |
+
tokens_before_block = num_block * block_length
|
| 183 |
+
tokens_after_block = tokens_before_block + block_length
|
| 184 |
+
if tokens_after_block > max_thinking_tokens:
|
| 185 |
+
gen_so_far = x_accum[:, prompt.size(1):current_block_start]
|
| 186 |
+
has_end_think = (
|
| 187 |
+
(gen_so_far == end_think_token_id).any(dim=1)
|
| 188 |
+
if gen_so_far.size(1) > 0
|
| 189 |
+
else torch.zeros(B, dtype=torch.bool, device=prompt.device)
|
| 190 |
+
)
|
| 191 |
+
if not has_end_think.all():
|
| 192 |
+
if tokens_before_block < max_thinking_tokens:
|
| 193 |
+
offset = max_thinking_tokens - tokens_before_block
|
| 194 |
+
else:
|
| 195 |
+
offset = 0
|
| 196 |
+
inject_pos = current_block_start + offset
|
| 197 |
+
for b in range(B):
|
| 198 |
+
if not has_end_think[b]:
|
| 199 |
+
x_accum[b, inject_pos] = end_think_token_id
|
| 200 |
+
|
| 201 |
+
# Build the initial mask for this block
|
| 202 |
+
mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
|
| 203 |
+
|
| 204 |
+
# Precompute the transfer schedule for this block
|
| 205 |
+
if dream_style:
|
| 206 |
+
# masked positions only (position 0 may be causal-seeded, not mask_id)
|
| 207 |
+
schedule_mask = mask_block_idx0
|
| 208 |
+
else:
|
| 209 |
+
schedule_mask = mask_block_idx0
|
| 210 |
+
|
| 211 |
+
num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
|
| 212 |
+
|
| 213 |
+
# Denoise the current block
|
| 214 |
+
for i in range(steps_per_block):
|
| 215 |
+
mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
|
| 216 |
+
if mask_block_idx.sum() == 0:
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
nfe += 1
|
| 220 |
+
|
| 221 |
+
# Forward only the current noisy block using cached context
|
| 222 |
+
logits_block = model(
|
| 223 |
+
x_accum[:, block_slice],
|
| 224 |
+
past_key_values=past_key_values,
|
| 225 |
+
use_cache=False
|
| 226 |
+
).logits
|
| 227 |
+
|
| 228 |
+
if dream_style:
|
| 229 |
+
# Align logits so that each masked position has a predictor:
|
| 230 |
+
# prepend context-next logit, then use logits_block[:-1]
|
| 231 |
+
if block_length == 1:
|
| 232 |
+
logits_use = next_logits_context # (B, 1, V)
|
| 233 |
+
else:
|
| 234 |
+
logits_use = torch.cat(
|
| 235 |
+
[next_logits_context, logits_block[:, :-1, :]],
|
| 236 |
+
dim=1
|
| 237 |
+
) # (B, Lb, V)
|
| 238 |
+
|
| 239 |
+
mask_use = mask_block_idx # (B, Lb)
|
| 240 |
+
x_use = x_accum[:, block_slice] # (B, Lb)
|
| 241 |
+
|
| 242 |
+
x0, transfer_idx = get_transfer_index(
|
| 243 |
+
logits_use, temperature, remasking, mask_use, x_use,
|
| 244 |
+
num_transfer_tokens=num_transfer_tokens[:, i],
|
| 245 |
+
threshold=threshold, neg_entropy=neg_entropy
|
| 246 |
+
)
|
| 247 |
+
cur = x_accum[:, block_slice].clone()
|
| 248 |
+
cur[transfer_idx] = x0[transfer_idx]
|
| 249 |
+
x_accum[:, block_slice] = cur
|
| 250 |
+
|
| 251 |
+
else:
|
| 252 |
+
# non-AR (same-position) case
|
| 253 |
+
x0, transfer_idx = get_transfer_index(
|
| 254 |
+
logits_block, temperature, remasking, mask_block_idx,
|
| 255 |
+
x_accum[:, block_slice],
|
| 256 |
+
num_transfer_tokens=num_transfer_tokens[:, i],
|
| 257 |
+
threshold=threshold, neg_entropy=neg_entropy
|
| 258 |
+
)
|
| 259 |
+
cur = x_accum[:, block_slice].clone()
|
| 260 |
+
cur[transfer_idx] = x0[transfer_idx]
|
| 261 |
+
x_accum[:, block_slice] = cur
|
| 262 |
+
|
| 263 |
+
if eos_token_id is not None:
|
| 264 |
+
block_tokens = x_accum[:, block_slice] # (B, Lb)
|
| 265 |
+
eos_mask = (block_tokens == eos_token_id) # (B, Lb)
|
| 266 |
+
any_eos = eos_mask.any(dim=1) # (B,)
|
| 267 |
+
if any_eos.any():
|
| 268 |
+
after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
|
| 269 |
+
mask_before = (block_tokens == mask_id) & ~after_eos
|
| 270 |
+
if (any_eos & ~mask_before.any(dim=1)).any():
|
| 271 |
+
break
|
| 272 |
+
|
| 273 |
+
if causal_context:
|
| 274 |
+
for layer in model_module.encoder.layers:
|
| 275 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 276 |
+
layer.self_attn.diffusion_lm=False
|
| 277 |
+
|
| 278 |
+
# after block is fully denoised, update KV cache
|
| 279 |
+
output = model(
|
| 280 |
+
x_accum[:, block_slice],
|
| 281 |
+
past_key_values=past_key_values,
|
| 282 |
+
use_cache=True,
|
| 283 |
+
use_causal_mask=causal_context
|
| 284 |
+
)
|
| 285 |
+
past_key_values = output.past_key_values
|
| 286 |
+
nfe += 1
|
| 287 |
+
|
| 288 |
+
if causal_context:
|
| 289 |
+
for layer in model_module.encoder.layers:
|
| 290 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 291 |
+
layer.self_attn.diffusion_lm=True
|
| 292 |
+
# Next block's first position = greedy/sampled next token from this causal encode
|
| 293 |
+
last_logit = output.logits[:, -1, :]
|
| 294 |
+
if temperature > 0:
|
| 295 |
+
probs = torch.softmax(last_logit / temperature, dim=-1)
|
| 296 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 297 |
+
else:
|
| 298 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 299 |
+
|
| 300 |
+
if dream_style and num_block < num_blocks - 1:
|
| 301 |
+
# refresh context-next logit for the next block
|
| 302 |
+
next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
|
| 303 |
+
|
| 304 |
+
if eos_token_id is not None:
|
| 305 |
+
gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
|
| 306 |
+
is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
|
| 307 |
+
has_eos = is_eos.any(dim=1) # (B,)
|
| 308 |
+
if has_eos.all():
|
| 309 |
+
first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
|
| 310 |
+
max_eos = first_eos_pos.max().item()
|
| 311 |
+
return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
|
| 312 |
+
|
| 313 |
+
return x_accum, nfe
|
config.json
CHANGED
|
@@ -1,21 +1,31 @@
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"ar_loss_weight": 1.0,
|
| 3 |
"architectures": [
|
| 4 |
-
"
|
| 5 |
],
|
| 6 |
"attention_bias": false,
|
| 7 |
"attention_dropout": 0.0,
|
| 8 |
"attn_implementation": "sdpa",
|
| 9 |
"auto_map": {
|
| 10 |
-
"AutoConfig": "
|
| 11 |
-
"AutoModel": "
|
| 12 |
},
|
| 13 |
"block_size": 32,
|
| 14 |
"bos_token_id": 1,
|
|
|
|
|
|
|
| 15 |
"dlm_loss_weight": null,
|
| 16 |
"dlm_paradigm": "bidirectional",
|
|
|
|
| 17 |
"dp_varying_mask_ratio": false,
|
|
|
|
|
|
|
| 18 |
"eos_token_id": 11,
|
|
|
|
| 19 |
"head_dim": 128,
|
| 20 |
"hidden_act": "silu",
|
| 21 |
"hidden_size": 3072,
|
|
@@ -24,10 +34,16 @@
|
|
| 24 |
"mask_token_id": 100,
|
| 25 |
"max_position_embeddings": 262144,
|
| 26 |
"mlp_bias": false,
|
| 27 |
-
"model_type": "
|
|
|
|
|
|
|
| 28 |
"num_attention_heads": 32,
|
|
|
|
| 29 |
"num_hidden_layers": 26,
|
| 30 |
"num_key_value_heads": 8,
|
|
|
|
|
|
|
|
|
|
| 31 |
"rms_norm_eps": 1e-05,
|
| 32 |
"rope_parameters": {
|
| 33 |
"beta_fast": 32.0,
|
|
@@ -42,6 +58,7 @@
|
|
| 42 |
},
|
| 43 |
"sliding_window": null,
|
| 44 |
"tie_word_embeddings": false,
|
|
|
|
| 45 |
"torch_dtype": "bfloat16",
|
| 46 |
"transformers_version": "5.0.0",
|
| 47 |
"use_cache": false,
|
|
|
|
| 1 |
{
|
| 2 |
+
"ada_dlm_loss_ratio": null,
|
| 3 |
+
"ada_perm_ratio_global": null,
|
| 4 |
+
"ada_perm_ratio_per_block": null,
|
| 5 |
+
"adaptive_mask_rate": false,
|
| 6 |
"ar_loss_weight": 1.0,
|
| 7 |
"architectures": [
|
| 8 |
+
"MinistralDiffEncoderModel"
|
| 9 |
],
|
| 10 |
"attention_bias": false,
|
| 11 |
"attention_dropout": 0.0,
|
| 12 |
"attn_implementation": "sdpa",
|
| 13 |
"auto_map": {
|
| 14 |
+
"AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
|
| 15 |
+
"AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
|
| 16 |
},
|
| 17 |
"block_size": 32,
|
| 18 |
"bos_token_id": 1,
|
| 19 |
+
"diff_loss_weight": 1,
|
| 20 |
+
"dlm_arch": "encoder",
|
| 21 |
"dlm_loss_weight": null,
|
| 22 |
"dlm_paradigm": "bidirectional",
|
| 23 |
+
"dlm_type": "llada",
|
| 24 |
"dp_varying_mask_ratio": false,
|
| 25 |
+
"enable_self_spec": false,
|
| 26 |
+
"enforce_mask": false,
|
| 27 |
"eos_token_id": 11,
|
| 28 |
+
"global_loss_avg": false,
|
| 29 |
"head_dim": 128,
|
| 30 |
"hidden_act": "silu",
|
| 31 |
"hidden_size": 3072,
|
|
|
|
| 34 |
"mask_token_id": 100,
|
| 35 |
"max_position_embeddings": 262144,
|
| 36 |
"mlp_bias": false,
|
| 37 |
+
"model_type": "ministral_dlm",
|
| 38 |
+
"multi_sampling": null,
|
| 39 |
+
"num_ar_layers": 0,
|
| 40 |
"num_attention_heads": 32,
|
| 41 |
+
"num_diffusion_layers": 0,
|
| 42 |
"num_hidden_layers": 26,
|
| 43 |
"num_key_value_heads": 8,
|
| 44 |
+
"num_skip_loss_tokens": 0,
|
| 45 |
+
"prefix_ratio": 0.8,
|
| 46 |
+
"random_length_prob": 0,
|
| 47 |
"rms_norm_eps": 1e-05,
|
| 48 |
"rope_parameters": {
|
| 49 |
"beta_fast": 32.0,
|
|
|
|
| 58 |
},
|
| 59 |
"sliding_window": null,
|
| 60 |
"tie_word_embeddings": false,
|
| 61 |
+
"tok_mask_half_life_ratio": null,
|
| 62 |
"torch_dtype": "bfloat16",
|
| 63 |
"transformers_version": "5.0.0",
|
| 64 |
"use_cache": false,
|
configuration_nemotron_labs_diffusion.py → configuration_ministral_dlm.py
RENAMED
|
@@ -12,7 +12,7 @@
|
|
| 12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
-
"""
|
| 16 |
|
| 17 |
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
from transformers.modeling_rope_utils import rope_config_validation
|
|
@@ -22,10 +22,10 @@ from transformers.utils import logging
|
|
| 22 |
logger = logging.get_logger(__name__)
|
| 23 |
|
| 24 |
|
| 25 |
-
class
|
| 26 |
r"""
|
| 27 |
-
This is the configuration class to store the configuration of a [`
|
| 28 |
-
It is used to instantiate a
|
| 29 |
|
| 30 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
documentation from [`PretrainedConfig`] for more information.
|
|
@@ -72,19 +72,52 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
|
|
| 72 |
Sliding window attention size.
|
| 73 |
mask_token_id (`int`, *optional*, defaults to -1):
|
| 74 |
Token ID for masking in diffusion.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
|
| 76 |
-
Paradigm for diffusion ('bidirectional', 'autoregressive', 'block_diff').
|
|
|
|
|
|
|
| 77 |
block_size (`int`, *optional*, defaults to 32):
|
| 78 |
Block size for block diffusion paradigms.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
dlm_loss_weight (`float`, *optional*):
|
| 80 |
Weight for diffusion LM loss.
|
| 81 |
ar_loss_weight (`float`, *optional*, defaults to 1.0):
|
| 82 |
-
Weight for autoregressive loss in
|
|
|
|
|
|
|
| 83 |
dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
|
| 84 |
Whether to use varying mask ratio for each DP rank during sampling.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
|
| 87 |
-
model_type = "
|
| 88 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 89 |
|
| 90 |
# Default tensor parallel plan for base model `Ministral`
|
|
@@ -129,11 +162,28 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
|
|
| 129 |
sliding_window=None,
|
| 130 |
attn_implementation="sdpa",
|
| 131 |
mask_token_id=-1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
dlm_paradigm='bidirectional',
|
|
|
|
| 133 |
block_size=32,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
dlm_loss_weight=None,
|
| 135 |
ar_loss_weight=1.0,
|
|
|
|
| 136 |
dp_varying_mask_ratio=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
**kwargs,
|
| 138 |
):
|
| 139 |
self.vocab_size = vocab_size
|
|
@@ -168,11 +218,28 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
|
|
| 168 |
self.attn_implementation = attn_implementation
|
| 169 |
|
| 170 |
self.mask_token_id = mask_token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
self.dlm_paradigm = dlm_paradigm
|
|
|
|
| 172 |
self.block_size = block_size
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
self.dlm_loss_weight = dlm_loss_weight
|
| 174 |
self.ar_loss_weight = ar_loss_weight
|
|
|
|
| 175 |
self.dp_varying_mask_ratio = dp_varying_mask_ratio
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
super().__init__(
|
| 177 |
pad_token_id=pad_token_id,
|
| 178 |
bos_token_id=bos_token_id,
|
|
@@ -182,5 +249,5 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
|
|
| 182 |
)
|
| 183 |
|
| 184 |
|
| 185 |
-
__all__ = ["
|
| 186 |
|
|
|
|
| 12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
+
"""Ministral DLM model configuration"""
|
| 16 |
|
| 17 |
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
from transformers.modeling_rope_utils import rope_config_validation
|
|
|
|
| 22 |
logger = logging.get_logger(__name__)
|
| 23 |
|
| 24 |
|
| 25 |
+
class MinistralDLMConfig(PretrainedConfig):
|
| 26 |
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
|
| 28 |
+
It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
|
| 29 |
|
| 30 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
documentation from [`PretrainedConfig`] for more information.
|
|
|
|
| 72 |
Sliding window attention size.
|
| 73 |
mask_token_id (`int`, *optional*, defaults to -1):
|
| 74 |
Token ID for masking in diffusion.
|
| 75 |
+
dlm_type (`str`, *optional*, defaults to 'llada'):
|
| 76 |
+
Type of diffusion language model ('llada', 'dream').
|
| 77 |
+
random_length_prob (`float`, *optional*):
|
| 78 |
+
Probability of using random lengths during training.
|
| 79 |
+
num_ar_layers (`int`, *optional*, defaults to 0):
|
| 80 |
+
Number of autoregressive layers.
|
| 81 |
+
num_diffusion_layers (`int`, *optional*, defaults to 0):
|
| 82 |
+
Number of diffusion layers.
|
| 83 |
+
diff_loss_weight (`float`, *optional*, defaults to 1):
|
| 84 |
+
Weight for diffusion loss.
|
| 85 |
+
enforce_mask (`bool`, *optional*, defaults to False):
|
| 86 |
+
Whether to enforce masking.
|
| 87 |
+
prefix_ratio (`float`, *optional*, defaults to 0.8):
|
| 88 |
+
Ratio for prefix in prefix_bidirectional mode.
|
| 89 |
dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
|
| 90 |
+
Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
|
| 91 |
+
dlm_arch (`str`, *optional*, defaults to 'encoder'):
|
| 92 |
+
Architecture type ('encoder', 'encoder_decoder').
|
| 93 |
block_size (`int`, *optional*, defaults to 32):
|
| 94 |
Block size for block diffusion paradigms.
|
| 95 |
+
tok_mask_half_life_ratio (`float`, *optional*):
|
| 96 |
+
Half-life ratio for token masking.
|
| 97 |
+
adaptive_mask_rate (`bool`, *optional*, defaults to False):
|
| 98 |
+
Whether to use adaptive mask rate.
|
| 99 |
+
multi_sampling (`int`, *optional*):
|
| 100 |
+
Number of samples for multi-sampling.
|
| 101 |
+
num_skip_loss_tokens (`int`, *optional*, defaults to 0):
|
| 102 |
+
Number of tokens to skip in loss calculation.
|
| 103 |
dlm_loss_weight (`float`, *optional*):
|
| 104 |
Weight for diffusion LM loss.
|
| 105 |
ar_loss_weight (`float`, *optional*, defaults to 1.0):
|
| 106 |
+
Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
|
| 107 |
+
global_loss_avg (`bool`, *optional*, defaults to False):
|
| 108 |
+
Whether to use global loss average.
|
| 109 |
dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
|
| 110 |
Whether to use varying mask ratio for each DP rank during sampling.
|
| 111 |
+
ada_perm_ratio_per_block (`float`, *optional*):
|
| 112 |
+
Adaptive permutation ratio for each block.
|
| 113 |
+
ada_perm_ratio_global (`float`, *optional*):
|
| 114 |
+
Adaptive permutation ratio for global.
|
| 115 |
+
enable_self_spec (`bool`, *optional*, defaults to `False`):
|
| 116 |
+
Force MinistralFlexAttention for all paradigms (including bidirectional/autoregressive).
|
| 117 |
+
Required for self speculative generation; leave False for standard eval to use faster SDPA kernels.
|
| 118 |
"""
|
| 119 |
|
| 120 |
+
model_type = "ministral_dlm"
|
| 121 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 122 |
|
| 123 |
# Default tensor parallel plan for base model `Ministral`
|
|
|
|
| 162 |
sliding_window=None,
|
| 163 |
attn_implementation="sdpa",
|
| 164 |
mask_token_id=-1,
|
| 165 |
+
dlm_type='llada',
|
| 166 |
+
random_length_prob=None,
|
| 167 |
+
num_ar_layers=0,
|
| 168 |
+
num_diffusion_layers=0,
|
| 169 |
+
diff_loss_weight=1,
|
| 170 |
+
enforce_mask=False,
|
| 171 |
+
prefix_ratio=0.8,
|
| 172 |
dlm_paradigm='bidirectional',
|
| 173 |
+
dlm_arch='encoder',
|
| 174 |
block_size=32,
|
| 175 |
+
tok_mask_half_life_ratio=None,
|
| 176 |
+
adaptive_mask_rate=False,
|
| 177 |
+
multi_sampling=None,
|
| 178 |
+
num_skip_loss_tokens=0,
|
| 179 |
dlm_loss_weight=None,
|
| 180 |
ar_loss_weight=1.0,
|
| 181 |
+
global_loss_avg=False,
|
| 182 |
dp_varying_mask_ratio=False,
|
| 183 |
+
ada_perm_ratio_per_block=None,
|
| 184 |
+
ada_perm_ratio_global=None,
|
| 185 |
+
ada_dlm_loss_ratio=None,
|
| 186 |
+
enable_self_spec=False,
|
| 187 |
**kwargs,
|
| 188 |
):
|
| 189 |
self.vocab_size = vocab_size
|
|
|
|
| 218 |
self.attn_implementation = attn_implementation
|
| 219 |
|
| 220 |
self.mask_token_id = mask_token_id
|
| 221 |
+
self.dlm_type = dlm_type
|
| 222 |
+
self.random_length_prob = random_length_prob
|
| 223 |
+
self.num_ar_layers = num_ar_layers
|
| 224 |
+
self.num_diffusion_layers = num_diffusion_layers
|
| 225 |
+
self.diff_loss_weight = diff_loss_weight
|
| 226 |
+
self.enforce_mask = enforce_mask
|
| 227 |
+
self.prefix_ratio = prefix_ratio
|
| 228 |
self.dlm_paradigm = dlm_paradigm
|
| 229 |
+
self.dlm_arch = dlm_arch
|
| 230 |
self.block_size = block_size
|
| 231 |
+
self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
|
| 232 |
+
self.adaptive_mask_rate = adaptive_mask_rate
|
| 233 |
+
self.multi_sampling = multi_sampling
|
| 234 |
+
self.num_skip_loss_tokens = num_skip_loss_tokens
|
| 235 |
self.dlm_loss_weight = dlm_loss_weight
|
| 236 |
self.ar_loss_weight = ar_loss_weight
|
| 237 |
+
self.global_loss_avg = global_loss_avg
|
| 238 |
self.dp_varying_mask_ratio = dp_varying_mask_ratio
|
| 239 |
+
self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
|
| 240 |
+
self.ada_perm_ratio_global = ada_perm_ratio_global
|
| 241 |
+
self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
|
| 242 |
+
self.enable_self_spec = enable_self_spec
|
| 243 |
super().__init__(
|
| 244 |
pad_token_id=pad_token_id,
|
| 245 |
bos_token_id=bos_token_id,
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
|
| 252 |
+
__all__ = ["MinistralDLMConfig"]
|
| 253 |
|
generation_config.json
CHANGED
|
@@ -2,6 +2,6 @@
|
|
| 2 |
"_from_model_config": true,
|
| 3 |
"bos_token_id": 1,
|
| 4 |
"eos_token_id": 11,
|
| 5 |
-
"transformers_version": "
|
| 6 |
"use_cache": false
|
| 7 |
}
|
|
|
|
| 2 |
"_from_model_config": true,
|
| 3 |
"bos_token_id": 1,
|
| 4 |
"eos_token_id": 11,
|
| 5 |
+
"transformers_version": "4.55.4",
|
| 6 |
"use_cache": false
|
| 7 |
}
|
linear_spec_lora/adapter_config.json
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"alpha_pattern": {},
|
| 3 |
-
"auto_mapping": {
|
| 4 |
-
"base_model_class": "NemotronLabsDiffusionModel",
|
| 5 |
-
"parent_library": "transformers_modules.Nemotron-Labs-Diffusion-3B.modeling_nemotron_labs_diffusion"
|
| 6 |
-
},
|
| 7 |
-
"base_model_name_or_path": "nvidia/Nemotron-Labs-Diffusion-3B",
|
| 8 |
-
"bias": "none",
|
| 9 |
-
"eva_config": null,
|
| 10 |
-
"exclude_modules": null,
|
| 11 |
-
"fan_in_fan_out": false,
|
| 12 |
-
"inference_mode": true,
|
| 13 |
-
"init_lora_weights": true,
|
| 14 |
-
"layer_replication": null,
|
| 15 |
-
"layers_pattern": null,
|
| 16 |
-
"layers_to_transform": null,
|
| 17 |
-
"loftq_config": {},
|
| 18 |
-
"lora_alpha": 512,
|
| 19 |
-
"lora_bias": false,
|
| 20 |
-
"lora_dropout": 0.0,
|
| 21 |
-
"megatron_config": null,
|
| 22 |
-
"megatron_core": "megatron.core",
|
| 23 |
-
"modules_to_save": null,
|
| 24 |
-
"peft_type": "LORA",
|
| 25 |
-
"r": 128,
|
| 26 |
-
"rank_pattern": {},
|
| 27 |
-
"revision": null,
|
| 28 |
-
"target_modules": [
|
| 29 |
-
"o_proj"
|
| 30 |
-
],
|
| 31 |
-
"task_type": null,
|
| 32 |
-
"use_dora": false,
|
| 33 |
-
"use_rslora": false
|
| 34 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
linear_spec_lora/adapter_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:897ef67dff8a69bd1a908fa390ef2164fdaa738e0e47bec502e2f0d86311ff74
|
| 3 |
-
size 95427600
|
|
|
|
|
|
|
|
|
|
|
|
model_cards/bias.md
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
Field | Response
|
| 2 |
-
:---------------------------------------------------------------------------------------------------|:---------------
|
| 3 |
-
Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | [None]
|
| 4 |
-
Measures taken to mitigate against unwanted bias: | [None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cards/explainability.md
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
Field | Response
|
| 2 |
-
:------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------
|
| 3 |
-
Intended Task/Domain: | Text generation
|
| 4 |
-
Model Type: | Transformer
|
| 5 |
-
Intended Users: | Generative AI creators working with conversational AI models.
|
| 6 |
-
Output: | Text (Responds to posed question, Stateful - remembers previous answers)
|
| 7 |
-
Describe how the model works: | Text input is encoded into tokens and passed into a transformer-based language model, which returns a text response.
|
| 8 |
-
Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable
|
| 9 |
-
Technical Limitations & Mitigation: | The model cannot perform long-horizon reasoning and tool calling.
|
| 10 |
-
Verified to have met prescribed NVIDIA quality standards: | Yes
|
| 11 |
-
Performance Metrics: | Accuracy, Latency, Throughput
|
| 12 |
-
Potential Known Risks: | In some instances, the model may think too long and struggle to derive final answers. The model's output can generate all forms of text, including what may be considered toxic, offensive, or indecent.
|
| 13 |
-
Licensing: | nvidia-open-model-license.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cards/privacy.md
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Field | Response
|
| 2 |
-
:----------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------
|
| 3 |
-
Generatable or reverse engineerable personal data? | [No]
|
| 4 |
-
Personal data used to create this model? | [No]
|
| 5 |
-
Was consent obtained for any personal data used? | [Not Applicable]
|
| 6 |
-
How often is dataset reviewed? | [During dataset creation, model training, evaluation, and the prerelease phase.]
|
| 7 |
-
Was data from user interactions with the AI model (e.g. user input and prompts) used to train the model? | [Yes]
|
| 8 |
-
Is there provenance for all datasets used in training? | Yes
|
| 9 |
-
Does data labeling (annotation, metadata) comply with privacy laws? | Yes
|
| 10 |
-
Is data compliant with data subject requests for data correction or removal, if such a request was made? | Not Applicable.
|
| 11 |
-
Applicable Privacy Policy | https://www.nvidia.com/en-us/about-nvidia/privacy-policy/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cards/safety.md
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
Field | Response
|
| 2 |
-
:---------------------------------------------------|:----------------------------------
|
| 3 |
-
Model Application Field(s): | [Media & Entertainment].
|
| 4 |
-
Describe the life critical impact (if present). | Not Applicable
|
| 5 |
-
Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to.
|
| 6 |
-
Use Case Restrictions: | Abide by nvidia-open-model-license.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_ministral.py
CHANGED
|
@@ -25,7 +25,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 25 |
from transformers.processing_utils import Unpack
|
| 26 |
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 27 |
# from transformers.utils.generic import maybe_autocast
|
| 28 |
-
from .
|
| 29 |
|
| 30 |
#ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
|
| 31 |
|
|
@@ -110,7 +110,7 @@ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_positi
|
|
| 110 |
class Ministral3Attention(nn.Module):
|
| 111 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 112 |
|
| 113 |
-
def __init__(self, config:
|
| 114 |
super().__init__()
|
| 115 |
self.config = config
|
| 116 |
self.layer_idx = layer_idx
|
|
@@ -234,7 +234,7 @@ class Ministral3RMSNorm(nn.Module):
|
|
| 234 |
|
| 235 |
|
| 236 |
class Ministral3DecoderLayer(GradientCheckpointingLayer):
|
| 237 |
-
def __init__(self, config:
|
| 238 |
super().__init__()
|
| 239 |
self.hidden_size = config.hidden_size
|
| 240 |
|
|
@@ -284,7 +284,7 @@ class Ministral3DecoderLayer(GradientCheckpointingLayer):
|
|
| 284 |
|
| 285 |
@auto_docstring
|
| 286 |
class Ministral3PreTrainedModel(PreTrainedModel):
|
| 287 |
-
config:
|
| 288 |
base_model_prefix = "model"
|
| 289 |
supports_gradient_checkpointing = True
|
| 290 |
_no_split_modules = ["Ministral3DecoderLayer"]
|
|
@@ -304,7 +304,7 @@ class Ministral3PreTrainedModel(PreTrainedModel):
|
|
| 304 |
class Ministral3RotaryEmbedding(nn.Module):
|
| 305 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 306 |
|
| 307 |
-
def __init__(self, config:
|
| 308 |
super().__init__()
|
| 309 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 310 |
self.original_max_seq_len = config.max_position_embeddings
|
|
@@ -323,7 +323,7 @@ class Ministral3RotaryEmbedding(nn.Module):
|
|
| 323 |
|
| 324 |
@staticmethod
|
| 325 |
def compute_default_rope_parameters(
|
| 326 |
-
config: Optional[
|
| 327 |
device: Optional["torch.device"] = None,
|
| 328 |
seq_len: Optional[int] = None,
|
| 329 |
) -> tuple["torch.Tensor", float]:
|
|
@@ -370,7 +370,7 @@ class Ministral3RotaryEmbedding(nn.Module):
|
|
| 370 |
|
| 371 |
@auto_docstring
|
| 372 |
class Ministral3Model(Ministral3PreTrainedModel):
|
| 373 |
-
def __init__(self, config:
|
| 374 |
super().__init__(config)
|
| 375 |
self.padding_idx = config.pad_token_id
|
| 376 |
self.vocab_size = config.vocab_size
|
|
@@ -453,7 +453,99 @@ class Ministral3Model(Ministral3PreTrainedModel):
|
|
| 453 |
)
|
| 454 |
|
| 455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
__all__ = [
|
|
|
|
|
|
|
| 457 |
"Ministral3Model",
|
| 458 |
"Ministral3PreTrainedModel",
|
|
|
|
|
|
|
| 459 |
]
|
|
|
|
| 25 |
from transformers.processing_utils import Unpack
|
| 26 |
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 27 |
# from transformers.utils.generic import maybe_autocast
|
| 28 |
+
from .configuration_ministral_dlm import MinistralDLMConfig
|
| 29 |
|
| 30 |
#ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
|
| 31 |
|
|
|
|
| 110 |
class Ministral3Attention(nn.Module):
|
| 111 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 112 |
|
| 113 |
+
def __init__(self, config: MinistralDLMConfig, layer_idx: int):
|
| 114 |
super().__init__()
|
| 115 |
self.config = config
|
| 116 |
self.layer_idx = layer_idx
|
|
|
|
| 234 |
|
| 235 |
|
| 236 |
class Ministral3DecoderLayer(GradientCheckpointingLayer):
|
| 237 |
+
def __init__(self, config: MinistralDLMConfig, layer_idx: int):
|
| 238 |
super().__init__()
|
| 239 |
self.hidden_size = config.hidden_size
|
| 240 |
|
|
|
|
| 284 |
|
| 285 |
@auto_docstring
|
| 286 |
class Ministral3PreTrainedModel(PreTrainedModel):
|
| 287 |
+
config: MinistralDLMConfig
|
| 288 |
base_model_prefix = "model"
|
| 289 |
supports_gradient_checkpointing = True
|
| 290 |
_no_split_modules = ["Ministral3DecoderLayer"]
|
|
|
|
| 304 |
class Ministral3RotaryEmbedding(nn.Module):
|
| 305 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 306 |
|
| 307 |
+
def __init__(self, config: MinistralDLMConfig, device=None):
|
| 308 |
super().__init__()
|
| 309 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 310 |
self.original_max_seq_len = config.max_position_embeddings
|
|
|
|
| 323 |
|
| 324 |
@staticmethod
|
| 325 |
def compute_default_rope_parameters(
|
| 326 |
+
config: Optional[MinistralDLMConfig] = None,
|
| 327 |
device: Optional["torch.device"] = None,
|
| 328 |
seq_len: Optional[int] = None,
|
| 329 |
) -> tuple["torch.Tensor", float]:
|
|
|
|
| 370 |
|
| 371 |
@auto_docstring
|
| 372 |
class Ministral3Model(Ministral3PreTrainedModel):
|
| 373 |
+
def __init__(self, config: MinistralDLMConfig):
|
| 374 |
super().__init__(config)
|
| 375 |
self.padding_idx = config.pad_token_id
|
| 376 |
self.vocab_size = config.vocab_size
|
|
|
|
| 453 |
)
|
| 454 |
|
| 455 |
|
| 456 |
+
@auto_docstring
|
| 457 |
+
class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
|
| 458 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 459 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 460 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 461 |
+
|
| 462 |
+
def __init__(self, config):
|
| 463 |
+
super().__init__(config)
|
| 464 |
+
self.model = Ministral3Model(config)
|
| 465 |
+
self.vocab_size = config.vocab_size
|
| 466 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 467 |
+
|
| 468 |
+
# Initialize weights and apply final processing
|
| 469 |
+
self.post_init()
|
| 470 |
+
|
| 471 |
+
@can_return_tuple
|
| 472 |
+
@auto_docstring
|
| 473 |
+
def forward(
|
| 474 |
+
self,
|
| 475 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 476 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 477 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 478 |
+
past_key_values: Optional[Cache] = None,
|
| 479 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 480 |
+
labels: Optional[torch.LongTensor] = None,
|
| 481 |
+
use_cache: Optional[bool] = None,
|
| 482 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 483 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 484 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 485 |
+
) -> CausalLMOutputWithPast:
|
| 486 |
+
r"""
|
| 487 |
+
Example:
|
| 488 |
+
|
| 489 |
+
```python
|
| 490 |
+
>>> from transformers import AutoTokenizer, Ministral3ForCausalLM
|
| 491 |
+
|
| 492 |
+
>>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
|
| 493 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
|
| 494 |
+
|
| 495 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 496 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 497 |
+
|
| 498 |
+
>>> # Generate
|
| 499 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 500 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 501 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 502 |
+
```"""
|
| 503 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 504 |
+
input_ids=input_ids,
|
| 505 |
+
attention_mask=attention_mask,
|
| 506 |
+
position_ids=position_ids,
|
| 507 |
+
past_key_values=past_key_values,
|
| 508 |
+
inputs_embeds=inputs_embeds,
|
| 509 |
+
use_cache=use_cache,
|
| 510 |
+
cache_position=cache_position,
|
| 511 |
+
**kwargs,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
hidden_states = outputs.last_hidden_state
|
| 515 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 516 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 517 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 518 |
+
|
| 519 |
+
loss = None
|
| 520 |
+
if labels is not None:
|
| 521 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 522 |
+
|
| 523 |
+
return CausalLMOutputWithPast(
|
| 524 |
+
loss=loss,
|
| 525 |
+
logits=logits,
|
| 526 |
+
past_key_values=outputs.past_key_values,
|
| 527 |
+
hidden_states=outputs.hidden_states,
|
| 528 |
+
attentions=outputs.attentions,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
|
| 533 |
+
pass
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
|
| 537 |
+
pass
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
|
| 541 |
+
pass
|
| 542 |
+
|
| 543 |
+
|
| 544 |
__all__ = [
|
| 545 |
+
"Ministral3ForCausalLM",
|
| 546 |
+
"Ministral3ForQuestionAnswering",
|
| 547 |
"Ministral3Model",
|
| 548 |
"Ministral3PreTrainedModel",
|
| 549 |
+
"Ministral3ForSequenceClassification",
|
| 550 |
+
"Ministral3ForTokenClassification",
|
| 551 |
]
|
modeling_ministral_dlm.py
ADDED
|
@@ -0,0 +1,1860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable, Optional, Tuple, Union
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn
|
| 13 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
|
| 14 |
+
from transformers.utils import ModelOutput
|
| 15 |
+
|
| 16 |
+
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
|
| 17 |
+
|
| 18 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 19 |
+
|
| 20 |
+
from transformers.processing_utils import Unpack
|
| 21 |
+
|
| 22 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 23 |
+
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from transformers.generation import GenerationMixin
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
|
| 30 |
+
from .chat_utils import generate_with_prefix_cache_block_diff
|
| 31 |
+
from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
|
| 32 |
+
from .configuration_ministral_dlm import MinistralDLMConfig
|
| 33 |
+
|
| 34 |
+
__all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class MinistralDiffOutputWithPast(ModelOutput):
|
| 38 |
+
loss: torch.FloatTensor | None = None
|
| 39 |
+
logits: torch.FloatTensor | None = None
|
| 40 |
+
causal_logits: torch.FloatTensor | None = None
|
| 41 |
+
past_key_values: Cache | None = None
|
| 42 |
+
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
| 43 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# @torch.compile(dynamic=True, mode="reduce-overhead")
|
| 47 |
+
# @torch.compile(mode="default")
|
| 48 |
+
# @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
|
| 49 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
|
| 50 |
+
def fused_flex_attention(q, k, v, block_mask=None):
|
| 51 |
+
return flex_attention(q, k, v, block_mask=block_mask)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
|
| 55 |
+
"""Crop a DynamicCache to max_length, compatible with both old and new transformers."""
|
| 56 |
+
if hasattr(past_key_values, 'crop'):
|
| 57 |
+
past_key_values.crop(max_length)
|
| 58 |
+
else:
|
| 59 |
+
for layer_idx in range(len(past_key_values)):
|
| 60 |
+
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
|
| 61 |
+
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
|
| 62 |
+
past_key_values._seen_tokens = max_length
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
|
| 66 |
+
"""After quadratic decoding, extract only draft tokens (first of each block) from cache."""
|
| 67 |
+
for layer_idx in range(len(past_key_values)):
|
| 68 |
+
if hasattr(past_key_values, 'layers'):
|
| 69 |
+
layer_cache = past_key_values.layers[layer_idx]
|
| 70 |
+
k, v = layer_cache.keys, layer_cache.values
|
| 71 |
+
else:
|
| 72 |
+
k = past_key_values.key_cache[layer_idx]
|
| 73 |
+
v = past_key_values.value_cache[layer_idx]
|
| 74 |
+
|
| 75 |
+
clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
|
| 76 |
+
clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
|
| 77 |
+
new_k = torch.cat([clean_k, draft_k], dim=2)
|
| 78 |
+
new_v = torch.cat([clean_v, draft_v], dim=2)
|
| 79 |
+
|
| 80 |
+
if hasattr(past_key_values, 'layers'):
|
| 81 |
+
layer_cache.keys = new_k
|
| 82 |
+
layer_cache.values = new_v
|
| 83 |
+
else:
|
| 84 |
+
past_key_values.key_cache[layer_idx] = new_k
|
| 85 |
+
past_key_values.value_cache[layer_idx] = new_v
|
| 86 |
+
|
| 87 |
+
past_key_values._seen_tokens = clean_len + block_length
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
|
| 91 |
+
class MinistralFlexAttention(Ministral3Attention):
|
| 92 |
+
def __init__(self, *args, **kwargs):
|
| 93 |
+
super().__init__(*args, **kwargs)
|
| 94 |
+
|
| 95 |
+
self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
|
| 96 |
+
self.block_size_orig = self.config.block_size
|
| 97 |
+
|
| 98 |
+
if self.config.dlm_paradigm == 'bidirectional':
|
| 99 |
+
self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
|
| 100 |
+
elif self.config.dlm_paradigm == 'autoregressive':
|
| 101 |
+
self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
|
| 102 |
+
elif self.config.dlm_paradigm == 'block_diff':
|
| 103 |
+
self.block_diff_mask = None
|
| 104 |
+
elif self.config.dlm_paradigm == 'sbd_block_diff':
|
| 105 |
+
self.sbd_block_diff_mask = None
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
|
| 108 |
+
|
| 109 |
+
self.block_size = self.block_size_orig
|
| 110 |
+
self.mode = self.config.dlm_paradigm
|
| 111 |
+
self._quadratic_block_mask = {}
|
| 112 |
+
|
| 113 |
+
import torch._dynamo.config as dcfg
|
| 114 |
+
dcfg.cache_size_limit = 512
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
|
| 118 |
+
if block_length not in self._quadratic_block_mask:
|
| 119 |
+
draft_len = block_length * (block_length + 1)
|
| 120 |
+
|
| 121 |
+
def quadratic(b, h, q_idx, kv_idx):
|
| 122 |
+
first_clean = torch.logical_and(
|
| 123 |
+
kv_idx % (block_length + 1) == 0,
|
| 124 |
+
kv_idx < draft_len,
|
| 125 |
+
)
|
| 126 |
+
first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
|
| 127 |
+
block_q = q_idx // (block_length + 1)
|
| 128 |
+
block_kv = kv_idx // (block_length + 1)
|
| 129 |
+
same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
|
| 130 |
+
same_block_except_first = torch.logical_and(
|
| 131 |
+
same_block,
|
| 132 |
+
q_idx % (block_length + 1) != 0,
|
| 133 |
+
)
|
| 134 |
+
draft_part = torch.logical_or(first_clean, same_block_except_first)
|
| 135 |
+
clean_part = kv_idx >= draft_len
|
| 136 |
+
return torch.logical_or(draft_part, clean_part)
|
| 137 |
+
|
| 138 |
+
block_mask = create_block_mask(
|
| 139 |
+
quadratic,
|
| 140 |
+
B=None,
|
| 141 |
+
H=None,
|
| 142 |
+
Q_LEN=draft_len,
|
| 143 |
+
KV_LEN=draft_len + self.config.max_position_embeddings,
|
| 144 |
+
device="cuda",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
self._quadratic_block_mask[block_length] = block_mask
|
| 148 |
+
|
| 149 |
+
return self._quadratic_block_mask[block_length]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def set_attention_mode(self, mode, block_size=None):
|
| 153 |
+
self.mode = mode
|
| 154 |
+
self.block_size = block_size
|
| 155 |
+
|
| 156 |
+
def compute_block_mask(self, mode, q_len=None, block_size=None):
|
| 157 |
+
|
| 158 |
+
def bidirectional_mask(b, h, q, kv):
|
| 159 |
+
return (q >= kv) | (q < kv)
|
| 160 |
+
|
| 161 |
+
def autoregressive_mask(b, h, q, kv):
|
| 162 |
+
return (q >= kv)
|
| 163 |
+
|
| 164 |
+
def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 165 |
+
x0_flag_q = (q_idx >= n)
|
| 166 |
+
x0_flag_kv = (kv_idx >= n)
|
| 167 |
+
|
| 168 |
+
# Compute block indices
|
| 169 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 170 |
+
(q_idx - n) // block_size,
|
| 171 |
+
q_idx // block_size)
|
| 172 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 173 |
+
(kv_idx - n) // block_size,
|
| 174 |
+
kv_idx // block_size)
|
| 175 |
+
|
| 176 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 177 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
|
| 178 |
+
|
| 179 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 180 |
+
offset_block_causal = (
|
| 181 |
+
(block_q > block_kv)
|
| 182 |
+
& (x0_flag_kv == 1)
|
| 183 |
+
& (x0_flag_q == 0)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# **3. Block-Causal Mask (M_BC) **
|
| 187 |
+
block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 188 |
+
|
| 189 |
+
# **4. Combine Masks **
|
| 190 |
+
return block_diagonal | offset_block_causal | block_causal
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 194 |
+
x0_flag_q = (q_idx >= n)
|
| 195 |
+
x0_flag_kv = (kv_idx >= n)
|
| 196 |
+
|
| 197 |
+
# Compute block indices
|
| 198 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 199 |
+
(q_idx - n) // block_size,
|
| 200 |
+
q_idx // block_size)
|
| 201 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 202 |
+
(kv_idx - n) // block_size,
|
| 203 |
+
kv_idx // block_size)
|
| 204 |
+
|
| 205 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 206 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
|
| 207 |
+
|
| 208 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 209 |
+
offset_block_causal = (
|
| 210 |
+
(block_q > block_kv)
|
| 211 |
+
& (x0_flag_kv == 1)
|
| 212 |
+
& (x0_flag_q == 0)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# **3. Fully Causal Mask (M_BC) **
|
| 216 |
+
fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 217 |
+
|
| 218 |
+
# **4. Combine Masks **
|
| 219 |
+
return block_diagonal | offset_block_causal | fully_causal
|
| 220 |
+
|
| 221 |
+
if mode == 'bidirectional':
|
| 222 |
+
attn_mask = bidirectional_mask
|
| 223 |
+
elif mode == 'autoregressive':
|
| 224 |
+
attn_mask = autoregressive_mask
|
| 225 |
+
elif mode == 'block_diff':
|
| 226 |
+
assert block_size is not None
|
| 227 |
+
attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
|
| 228 |
+
elif mode == 'sbd_block_diff':
|
| 229 |
+
assert block_size is not None
|
| 230 |
+
attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
|
| 231 |
+
else:
|
| 232 |
+
raise ValueError(f"Unknown attention mode: {mode}")
|
| 233 |
+
|
| 234 |
+
if q_len is not None:
|
| 235 |
+
Q_LEN = q_len
|
| 236 |
+
else:
|
| 237 |
+
if mode in ['block_diff', 'sbd_block_diff']:
|
| 238 |
+
Q_LEN = self.max_seq_length * 2
|
| 239 |
+
else:
|
| 240 |
+
Q_LEN = self.max_seq_length
|
| 241 |
+
|
| 242 |
+
block_mask = create_block_mask(
|
| 243 |
+
attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return block_mask
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
hidden_states: torch.Tensor,
|
| 252 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 253 |
+
attention_mask: Optional[torch.Tensor],
|
| 254 |
+
past_key_values: Optional[Cache] = None,
|
| 255 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 256 |
+
is_training: bool = True,
|
| 257 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 258 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 259 |
+
bsz, q_len, _ = hidden_states.size()
|
| 260 |
+
input_shape = hidden_states.shape[:-1]
|
| 261 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 262 |
+
|
| 263 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 264 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 265 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 266 |
+
|
| 267 |
+
cos, sin = position_embeddings
|
| 268 |
+
|
| 269 |
+
if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
|
| 270 |
+
# Split query and key states in half along sequence length dimension
|
| 271 |
+
q1, q2 = query_states.chunk(2, dim=2)
|
| 272 |
+
k1, k2 = key_states.chunk(2, dim=2)
|
| 273 |
+
|
| 274 |
+
# Apply RoPE independently to each half
|
| 275 |
+
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
| 276 |
+
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
| 277 |
+
|
| 278 |
+
# Recombine the halves
|
| 279 |
+
query_states = torch.cat([q1, q2], dim=2)
|
| 280 |
+
key_states = torch.cat([k1, k2], dim=2)
|
| 281 |
+
else:
|
| 282 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 283 |
+
|
| 284 |
+
query_states = query_states * _get_llama_4_attn_scale(
|
| 285 |
+
cache_position,
|
| 286 |
+
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 287 |
+
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 288 |
+
).to(query_states.dtype)
|
| 289 |
+
|
| 290 |
+
if past_key_values is not None:
|
| 291 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 292 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 293 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 294 |
+
|
| 295 |
+
self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
|
| 296 |
+
if self_spec_inference_mode is not None:
|
| 297 |
+
if self_spec_inference_mode == "quadratic":
|
| 298 |
+
block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
|
| 299 |
+
if block_length is None:
|
| 300 |
+
raise ValueError("SBD quadratic decoding requires block_length in config.")
|
| 301 |
+
if past_key_values is not None:
|
| 302 |
+
seq_len = key_states.shape[2]
|
| 303 |
+
draft_len = block_length * (block_length + 1)
|
| 304 |
+
|
| 305 |
+
clean_keys = key_states[:, :, :-draft_len]
|
| 306 |
+
draft_keys = key_states[:, :, -draft_len:]
|
| 307 |
+
clean_values = value_states[:, :, :-draft_len]
|
| 308 |
+
draft_values = value_states[:, :, -draft_len:]
|
| 309 |
+
key_states = torch.cat([draft_keys, clean_keys], dim=2)
|
| 310 |
+
value_states = torch.cat([draft_values, clean_values], dim=2)
|
| 311 |
+
|
| 312 |
+
block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
|
| 313 |
+
block_length=block_length
|
| 314 |
+
)
|
| 315 |
+
block_mask.seq_lengths = (draft_len, seq_len)
|
| 316 |
+
else:
|
| 317 |
+
seq_len = query_states.shape[2]
|
| 318 |
+
draft_len = block_length * (block_length + 1)
|
| 319 |
+
clean_len = seq_len - draft_len
|
| 320 |
+
|
| 321 |
+
def _causal_mask(b, h, q_idx, kv_idx):
|
| 322 |
+
return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
|
| 323 |
+
|
| 324 |
+
def _draft2clean_mask(b, h, q_idx, kv_idx):
|
| 325 |
+
full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
|
| 326 |
+
first_clean = torch.logical_and(
|
| 327 |
+
q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
|
| 328 |
+
)
|
| 329 |
+
first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
|
| 330 |
+
return torch.logical_or(full_clean, first_clean)
|
| 331 |
+
|
| 332 |
+
def _draft_mask(b, h, q_idx, kv_idx):
|
| 333 |
+
block_q = (q_idx - clean_len) // (block_length + 1)
|
| 334 |
+
block_kv = (kv_idx - clean_len) // (block_length + 1)
|
| 335 |
+
quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
|
| 336 |
+
same_block = torch.logical_and(block_q == block_kv, quadrant)
|
| 337 |
+
same_block_except_first = torch.logical_and(
|
| 338 |
+
same_block,
|
| 339 |
+
(q_idx - clean_len) % (block_length + 1) != 0,
|
| 340 |
+
)
|
| 341 |
+
return torch.logical_and(block_q == block_kv, same_block_except_first)
|
| 342 |
+
|
| 343 |
+
mask = or_masks(_causal_mask, _draft2clean_mask)
|
| 344 |
+
mask = or_masks(mask, _draft_mask)
|
| 345 |
+
|
| 346 |
+
block_mask = create_block_mask(
|
| 347 |
+
mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 351 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 352 |
+
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 353 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 354 |
+
attn_output = self.o_proj(attn_output)
|
| 355 |
+
return attn_output, None
|
| 356 |
+
|
| 357 |
+
elif self_spec_inference_mode == "default":
|
| 358 |
+
block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
|
| 359 |
+
if block_length is None:
|
| 360 |
+
raise ValueError("SBD default decoding requires block_length in config.")
|
| 361 |
+
seq_len = query_states.shape[2]
|
| 362 |
+
prefix_len = seq_len - block_length
|
| 363 |
+
|
| 364 |
+
def _clean_q_mask(b, h, q_idx, kv_idx):
|
| 365 |
+
return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
|
| 366 |
+
|
| 367 |
+
def _noisy_q_mask(b, h, q_idx, kv_idx):
|
| 368 |
+
return q_idx >= prefix_len
|
| 369 |
+
|
| 370 |
+
block_mask = create_block_mask(
|
| 371 |
+
or_masks(_clean_q_mask, _noisy_q_mask),
|
| 372 |
+
B=None,
|
| 373 |
+
H=None,
|
| 374 |
+
Q_LEN=seq_len,
|
| 375 |
+
KV_LEN=seq_len,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 379 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 380 |
+
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 381 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 382 |
+
attn_output = self.o_proj(attn_output)
|
| 383 |
+
return attn_output, None
|
| 384 |
+
|
| 385 |
+
else:
|
| 386 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 387 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 388 |
+
|
| 389 |
+
if self.mode == 'bidirectional':
|
| 390 |
+
if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
|
| 391 |
+
block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
|
| 392 |
+
else:
|
| 393 |
+
block_mask = self.bidirectional_mask
|
| 394 |
+
|
| 395 |
+
elif self.mode == 'autoregressive':
|
| 396 |
+
if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
|
| 397 |
+
block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
|
| 398 |
+
else:
|
| 399 |
+
block_mask = self.autoregressive_mask
|
| 400 |
+
|
| 401 |
+
elif self.mode == 'block_diff':
|
| 402 |
+
if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
|
| 403 |
+
block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
|
| 404 |
+
else:
|
| 405 |
+
block_mask = self.block_diff_mask
|
| 406 |
+
elif self.mode == 'sbd_block_diff':
|
| 407 |
+
if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
|
| 408 |
+
block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
|
| 409 |
+
else:
|
| 410 |
+
block_mask = self.sbd_block_diff_mask
|
| 411 |
+
else:
|
| 412 |
+
raise ValueError(f"Unknown attention mode: {self.mode}")
|
| 413 |
+
|
| 414 |
+
attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 415 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 416 |
+
|
| 417 |
+
attn_output = self.o_proj(attn_output)
|
| 418 |
+
|
| 419 |
+
return attn_output, None
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
|
| 423 |
+
"""Return a Bool mask of length len(log_w) with exactly k True."""
|
| 424 |
+
g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
|
| 425 |
+
topk = torch.topk(log_w + g, k).indices
|
| 426 |
+
mask = torch.zeros_like(log_w, dtype=torch.bool)
|
| 427 |
+
mask[topk] = True
|
| 428 |
+
return mask
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
|
| 432 |
+
"""
|
| 433 |
+
A single model with:
|
| 434 |
+
- a bidirectional encoder + diffusion‐LM head over A
|
| 435 |
+
- a causal decoder + LM head over B, conditioned on F_A
|
| 436 |
+
"""
|
| 437 |
+
|
| 438 |
+
def __init__(self, config: MinistralDLMConfig):
|
| 439 |
+
super().__init__(config)
|
| 440 |
+
|
| 441 |
+
self.mask_token_id = config.mask_token_id
|
| 442 |
+
|
| 443 |
+
diffusion_config = copy.deepcopy(config)
|
| 444 |
+
diffusion_config.diffusion_lm = True
|
| 445 |
+
|
| 446 |
+
use_flex = getattr(config, 'enable_self_spec', False)
|
| 447 |
+
|
| 448 |
+
if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 449 |
+
diffusion_config.attn_class = MinistralFlexAttention
|
| 450 |
+
elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
|
| 451 |
+
diffusion_config.attn_class = MinistralFlexAttention if use_flex else Ministral3Attention
|
| 452 |
+
if config.dlm_paradigm == 'autoregressive':
|
| 453 |
+
diffusion_config.diffusion_lm = False
|
| 454 |
+
else:
|
| 455 |
+
raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
|
| 456 |
+
|
| 457 |
+
self.encoder = Ministral3Model(diffusion_config)
|
| 458 |
+
self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 459 |
+
self.vocab_size = config.vocab_size
|
| 460 |
+
|
| 461 |
+
self.current_iter_ratio = None
|
| 462 |
+
|
| 463 |
+
self.post_init()
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def get_input_embeddings(self):
|
| 467 |
+
return self.encoder.embed_tokens
|
| 468 |
+
|
| 469 |
+
def set_input_embeddings(self, value):
|
| 470 |
+
self.encoder.embed_tokens = value
|
| 471 |
+
|
| 472 |
+
def get_output_embeddings(self):
|
| 473 |
+
return self.diffusion_head
|
| 474 |
+
|
| 475 |
+
def set_output_embeddings(self, new_embeddings):
|
| 476 |
+
self.diffusion_head = new_embeddings
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
|
| 480 |
+
b, l = input_ids.shape
|
| 481 |
+
device = input_ids.device
|
| 482 |
+
|
| 483 |
+
if self.config.dp_varying_mask_ratio:
|
| 484 |
+
# Enable different random seeds for each DP rank during sampling
|
| 485 |
+
import torch.distributed as dist
|
| 486 |
+
dp_rank = 0
|
| 487 |
+
if dist.is_initialized():
|
| 488 |
+
try:
|
| 489 |
+
dp_rank = dist.get_rank()
|
| 490 |
+
except Exception:
|
| 491 |
+
dp_rank = 0
|
| 492 |
+
# Use a local generator to avoid affecting global RNG state
|
| 493 |
+
generator = torch.Generator(device=device)
|
| 494 |
+
generator.manual_seed(torch.seed() + dp_rank)
|
| 495 |
+
else:
|
| 496 |
+
generator = None
|
| 497 |
+
|
| 498 |
+
if self.config.adaptive_mask_rate:
|
| 499 |
+
assert block_size is not None
|
| 500 |
+
|
| 501 |
+
# --- simple linear window mapping ---
|
| 502 |
+
bs_min = getattr(self.config, "t_bs_min", 16)
|
| 503 |
+
bs_max = getattr(self.config, "t_bs_max", 128)
|
| 504 |
+
w = getattr(self.config, "t_window_width", 0.6) # fixed width
|
| 505 |
+
|
| 506 |
+
# fraction in [0,1] (unclamped first)
|
| 507 |
+
frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
|
| 508 |
+
# upper bound decreases linearly from 1.0 -> 0.5
|
| 509 |
+
u_max = 1.0 - w * frac
|
| 510 |
+
# clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
|
| 511 |
+
u_max = max(0.6, min(1.0, u_max))
|
| 512 |
+
u_min = u_max - w # ensures width = w
|
| 513 |
+
|
| 514 |
+
# sample t ~ Uniform(u_min, u_max)
|
| 515 |
+
t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
|
| 516 |
+
else:
|
| 517 |
+
t = torch.rand(b, device=device, generator=generator)
|
| 518 |
+
|
| 519 |
+
p_mask = (1 - eps) * t + eps # shape: (b,)
|
| 520 |
+
p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
|
| 521 |
+
|
| 522 |
+
masked_indices = torch.rand((b, l), device=device) < p_mask
|
| 523 |
+
|
| 524 |
+
if loss_mask is not None:
|
| 525 |
+
masked_indices[loss_mask == 0] = 0
|
| 526 |
+
|
| 527 |
+
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 528 |
+
|
| 529 |
+
return noisy_batch, masked_indices, p_mask
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def forward_process_exp(
|
| 533 |
+
self,
|
| 534 |
+
input_ids: torch.Tensor,
|
| 535 |
+
eps: float = 1e-3,
|
| 536 |
+
block_size: int | None = None,
|
| 537 |
+
half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
|
| 538 |
+
loss_mask: Optional[torch.Tensor] = None,
|
| 539 |
+
):
|
| 540 |
+
"""
|
| 541 |
+
Two-stage corruption with optional per-block sampling.
|
| 542 |
+
• Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
|
| 543 |
+
• Stage 2: sample exactly k positions with weights
|
| 544 |
+
w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
|
| 545 |
+
uniform when m→1).
|
| 546 |
+
If `block_size` is given, the procedure is run *independently*
|
| 547 |
+
inside each contiguous block of that length (last block may be shorter).
|
| 548 |
+
When block_size is provided, m is sampled per-block and p_mask is per-block.
|
| 549 |
+
Args
|
| 550 |
+
----
|
| 551 |
+
input_ids : (B, L) LongTensor
|
| 552 |
+
eps : minimum corruption ratio
|
| 553 |
+
block_size: if not None, operate block-wise with per-block m sampling
|
| 554 |
+
half_life_ratio : controls steepness when m→0
|
| 555 |
+
"""
|
| 556 |
+
B, L = input_ids.shape
|
| 557 |
+
device = input_ids.device
|
| 558 |
+
dtype = torch.float32
|
| 559 |
+
|
| 560 |
+
masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
|
| 561 |
+
p_mask = torch.zeros((B, L), dtype=dtype, device=device)
|
| 562 |
+
|
| 563 |
+
# ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
|
| 564 |
+
for b in range(B):
|
| 565 |
+
if block_size is None:
|
| 566 |
+
# ---------- Per-batch sampling (original behavior) ----------
|
| 567 |
+
m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
|
| 568 |
+
k_tot = int(round(m * L))
|
| 569 |
+
k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
|
| 570 |
+
|
| 571 |
+
# Fill p_mask for this batch
|
| 572 |
+
p_mask[b, :] = m
|
| 573 |
+
|
| 574 |
+
slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
|
| 575 |
+
|
| 576 |
+
# ------- single pool over the whole sentence -------------
|
| 577 |
+
lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
|
| 578 |
+
|
| 579 |
+
pos = torch.arange(L, device=device, dtype=dtype)
|
| 580 |
+
log_w = (lam_base * slope * pos).clone()
|
| 581 |
+
|
| 582 |
+
masked_indices[b] = gumbel_topk(log_w, k_tot)
|
| 583 |
+
|
| 584 |
+
else:
|
| 585 |
+
# ---------- Per-block sampling ----------
|
| 586 |
+
num_blocks = math.ceil(L / block_size)
|
| 587 |
+
lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
|
| 588 |
+
|
| 589 |
+
for blk in range(num_blocks):
|
| 590 |
+
start = blk * block_size
|
| 591 |
+
end = min((blk + 1) * block_size, L)
|
| 592 |
+
blk_len = end - start
|
| 593 |
+
|
| 594 |
+
# Sample m per block
|
| 595 |
+
m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
|
| 596 |
+
|
| 597 |
+
# Fill p_mask for this block
|
| 598 |
+
p_mask[b, start:end] = m_blk
|
| 599 |
+
|
| 600 |
+
# per-block budget
|
| 601 |
+
k_blk = int(round(m_blk * blk_len))
|
| 602 |
+
k_blk = max(0, min(k_blk, blk_len))
|
| 603 |
+
if k_blk == 0:
|
| 604 |
+
continue
|
| 605 |
+
|
| 606 |
+
slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
|
| 607 |
+
|
| 608 |
+
pos = torch.arange(blk_len, device=device, dtype=dtype)
|
| 609 |
+
log_w = lam_base * slope * pos
|
| 610 |
+
|
| 611 |
+
blk_mask = gumbel_topk(log_w, k_blk)
|
| 612 |
+
masked_indices[b, start:end] = blk_mask
|
| 613 |
+
|
| 614 |
+
if loss_mask is not None:
|
| 615 |
+
masked_indices[loss_mask == 0] = 0
|
| 616 |
+
|
| 617 |
+
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 618 |
+
return noisy_batch, masked_indices, p_mask
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def forward(
|
| 622 |
+
self,
|
| 623 |
+
input_ids: torch.LongTensor,
|
| 624 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 625 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 626 |
+
labels: Optional[torch.LongTensor] = None,
|
| 627 |
+
split_len: Optional[int] = None,
|
| 628 |
+
past_key_values: Optional[Cache] = None,
|
| 629 |
+
block_size: Optional[int] = None,
|
| 630 |
+
block_diff_ppl: bool = False,
|
| 631 |
+
eps: float = 1e-3,
|
| 632 |
+
is_teacher: bool = False,
|
| 633 |
+
masked_indices: Optional[torch.Tensor] = None,
|
| 634 |
+
p_mask: Optional[torch.Tensor] = None,
|
| 635 |
+
teacher_logits: Optional[torch.Tensor] = None,
|
| 636 |
+
masked_indices_teacher: Optional[torch.Tensor] = None,
|
| 637 |
+
loss_mask: Optional[torch.Tensor] = None,
|
| 638 |
+
ce_loss_weight: float = 1.0,
|
| 639 |
+
output_last_hidden_states_only: bool = False,
|
| 640 |
+
skip_loss: bool = False,
|
| 641 |
+
**kwargs,
|
| 642 |
+
) -> CausalLMOutputWithPast:
|
| 643 |
+
|
| 644 |
+
batch_size, seq_len = input_ids.shape
|
| 645 |
+
|
| 646 |
+
if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
|
| 647 |
+
if labels is not None and torch.rand(1) < self.config.random_length_prob:
|
| 648 |
+
random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
|
| 649 |
+
input_ids = input_ids[:, :random_length]
|
| 650 |
+
labels = labels[:, :random_length]
|
| 651 |
+
|
| 652 |
+
if attention_mask is not None:
|
| 653 |
+
attention_mask = attention_mask[:, :random_length]
|
| 654 |
+
if position_ids is not None:
|
| 655 |
+
position_ids = position_ids[:, :random_length]
|
| 656 |
+
if loss_mask is not None:
|
| 657 |
+
loss_mask = loss_mask[:, :random_length]
|
| 658 |
+
|
| 659 |
+
elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 660 |
+
if labels is not None and block_size is None:
|
| 661 |
+
if torch.rand(1) < self.config.random_length_prob:
|
| 662 |
+
block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
|
| 663 |
+
else:
|
| 664 |
+
block_size = self.config.block_size
|
| 665 |
+
|
| 666 |
+
else:
|
| 667 |
+
raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
|
| 668 |
+
|
| 669 |
+
if labels is not None and self.config.dlm_paradigm != 'autoregressive':
|
| 670 |
+
if masked_indices is not None:
|
| 671 |
+
# assert p_mask is not None
|
| 672 |
+
|
| 673 |
+
if loss_mask is not None:
|
| 674 |
+
masked_indices[loss_mask == 0] = 0
|
| 675 |
+
|
| 676 |
+
noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 677 |
+
|
| 678 |
+
else:
|
| 679 |
+
if self.config.tok_mask_half_life_ratio is not None:
|
| 680 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
|
| 681 |
+
else:
|
| 682 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
|
| 683 |
+
|
| 684 |
+
else:
|
| 685 |
+
noisy_inputs = input_ids
|
| 686 |
+
masked_indices = None
|
| 687 |
+
p_mask = None
|
| 688 |
+
|
| 689 |
+
if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 690 |
+
for layer in self.encoder.layers:
|
| 691 |
+
if hasattr(layer.self_attn, 'set_attention_mode'):
|
| 692 |
+
layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
|
| 693 |
+
|
| 694 |
+
input_ids_len = noisy_inputs.shape[1]
|
| 695 |
+
if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 696 |
+
if position_ids is None:
|
| 697 |
+
position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
|
| 698 |
+
noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
|
| 699 |
+
|
| 700 |
+
if block_diff_ppl:
|
| 701 |
+
if position_ids is None:
|
| 702 |
+
position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
|
| 703 |
+
|
| 704 |
+
enc_out = self.encoder(
|
| 705 |
+
past_key_values=past_key_values,
|
| 706 |
+
input_ids=noisy_inputs,
|
| 707 |
+
attention_mask=attention_mask,
|
| 708 |
+
position_ids=position_ids,
|
| 709 |
+
is_training=(labels is not None) or (block_diff_ppl),
|
| 710 |
+
**kwargs,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
if output_last_hidden_states_only:
|
| 714 |
+
return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
|
| 715 |
+
|
| 716 |
+
logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
|
| 717 |
+
causal_logits = None
|
| 718 |
+
|
| 719 |
+
if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 720 |
+
if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 721 |
+
causal_logits = logits[:, input_ids_len:]
|
| 722 |
+
else:
|
| 723 |
+
causal_logits = None
|
| 724 |
+
|
| 725 |
+
logits = logits[:, :input_ids_len]
|
| 726 |
+
|
| 727 |
+
loss = None
|
| 728 |
+
if labels is not None and not skip_loss:
|
| 729 |
+
if self.config.dlm_paradigm == 'autoregressive':
|
| 730 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 731 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 732 |
+
|
| 733 |
+
if loss_mask is None:
|
| 734 |
+
loss_fct = CrossEntropyLoss()
|
| 735 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 736 |
+
shift_labels = shift_labels.view(-1)
|
| 737 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 738 |
+
|
| 739 |
+
else:
|
| 740 |
+
loss_mask = loss_mask[..., 1:].contiguous()
|
| 741 |
+
|
| 742 |
+
loss_fct = CrossEntropyLoss(reduction='none')
|
| 743 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 744 |
+
shift_labels = shift_labels.view(-1)
|
| 745 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 746 |
+
|
| 747 |
+
token_losses = loss_fct(shift_logits, shift_labels)
|
| 748 |
+
|
| 749 |
+
flat_loss_mask = loss_mask.reshape(-1)
|
| 750 |
+
loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
|
| 751 |
+
|
| 752 |
+
else:
|
| 753 |
+
# Handle DREAM vs LLADA style losses
|
| 754 |
+
if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
|
| 755 |
+
logits = logits[..., :-1, :].contiguous()
|
| 756 |
+
labels = labels[..., 1:].contiguous()
|
| 757 |
+
masked_indices = masked_indices[:, 1:]
|
| 758 |
+
p_mask = p_mask[:, 1:]
|
| 759 |
+
|
| 760 |
+
if self.config.ada_perm_ratio_per_block is not None:
|
| 761 |
+
# Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
|
| 762 |
+
block_size = self.config.block_size
|
| 763 |
+
batch_size, seq_len = masked_indices.shape
|
| 764 |
+
num_blocks = seq_len // block_size
|
| 765 |
+
|
| 766 |
+
# Get the max logit (confidence) for each position
|
| 767 |
+
confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
|
| 768 |
+
|
| 769 |
+
# Create a mask for tokens to include in loss
|
| 770 |
+
selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
|
| 771 |
+
|
| 772 |
+
for blk in range(num_blocks):
|
| 773 |
+
start = blk * block_size
|
| 774 |
+
end = min((blk + 1) * block_size, seq_len)
|
| 775 |
+
|
| 776 |
+
# Get masked indices within this block
|
| 777 |
+
block_masked = masked_indices[:, start:end] # (batch_size, block_len)
|
| 778 |
+
block_confidence = confidence[:, start:end] # (batch_size, block_len)
|
| 779 |
+
|
| 780 |
+
for b in range(batch_size):
|
| 781 |
+
# Get positions that are masked in this block for this batch
|
| 782 |
+
masked_positions = torch.where(block_masked[b])[0]
|
| 783 |
+
num_masked = len(masked_positions)
|
| 784 |
+
|
| 785 |
+
if num_masked > 0:
|
| 786 |
+
# Number of tokens to keep (top by confidence)
|
| 787 |
+
k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
|
| 788 |
+
|
| 789 |
+
# Get confidence values for masked positions
|
| 790 |
+
masked_confidence = block_confidence[b, masked_positions]
|
| 791 |
+
|
| 792 |
+
# Get indices of top-k confident tokens
|
| 793 |
+
_, topk_indices = torch.topk(masked_confidence, k)
|
| 794 |
+
selected_positions = masked_positions[topk_indices]
|
| 795 |
+
|
| 796 |
+
# Mark these positions in the selected mask
|
| 797 |
+
selected_mask[b, start + selected_positions] = True
|
| 798 |
+
|
| 799 |
+
# Calculate loss only for selected positions
|
| 800 |
+
token_loss = torch.nn.functional.cross_entropy(
|
| 801 |
+
logits[selected_mask],
|
| 802 |
+
labels[selected_mask],
|
| 803 |
+
reduction='none'
|
| 804 |
+
) / p_mask[selected_mask]
|
| 805 |
+
|
| 806 |
+
num_mask_tokens = selected_mask.sum()
|
| 807 |
+
|
| 808 |
+
else:
|
| 809 |
+
# Calculate token-wise cross entropy loss for masked positions in B
|
| 810 |
+
token_loss = torch.nn.functional.cross_entropy(
|
| 811 |
+
logits[masked_indices],
|
| 812 |
+
labels[masked_indices],
|
| 813 |
+
reduction='none'
|
| 814 |
+
) / p_mask[masked_indices]
|
| 815 |
+
|
| 816 |
+
num_mask_tokens = masked_indices.sum()
|
| 817 |
+
|
| 818 |
+
if self.config.global_loss_avg:
|
| 819 |
+
loss = token_loss.sum()
|
| 820 |
+
else:
|
| 821 |
+
loss = token_loss.sum() / num_mask_tokens
|
| 822 |
+
|
| 823 |
+
if self.config.ada_dlm_loss_ratio is not None:
|
| 824 |
+
assert self.current_iter_ratio is not None
|
| 825 |
+
assert self.config.dlm_loss_weight is not None
|
| 826 |
+
|
| 827 |
+
dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
|
| 828 |
+
loss = dlm_loss_weight * loss
|
| 829 |
+
|
| 830 |
+
elif self.config.dlm_loss_weight is not None:
|
| 831 |
+
loss = self.config.dlm_loss_weight * loss
|
| 832 |
+
|
| 833 |
+
if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 834 |
+
causal_logits = causal_logits[..., :-1, :].contiguous()
|
| 835 |
+
causal_logits = causal_logits.view(-1, causal_logits.size(-1))
|
| 836 |
+
|
| 837 |
+
if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
|
| 838 |
+
causal_labels = labels.view(-1)
|
| 839 |
+
else:
|
| 840 |
+
causal_labels = labels[..., 1:].contiguous().view(-1)
|
| 841 |
+
|
| 842 |
+
if self.config.global_loss_avg:
|
| 843 |
+
loss_fct = CrossEntropyLoss(reduction='sum')
|
| 844 |
+
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 845 |
+
|
| 846 |
+
self.loss_diffusion = loss.detach().item() / num_mask_tokens
|
| 847 |
+
self.loss_ar = ar_loss.detach().item() / seq_len
|
| 848 |
+
|
| 849 |
+
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 850 |
+
else:
|
| 851 |
+
loss_fct = CrossEntropyLoss()
|
| 852 |
+
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 853 |
+
|
| 854 |
+
self.loss_diffusion = loss.detach().item()
|
| 855 |
+
self.loss_ar = ar_loss.detach().item()
|
| 856 |
+
|
| 857 |
+
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 858 |
+
|
| 859 |
+
if self.config.global_loss_avg:
|
| 860 |
+
if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 861 |
+
loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
|
| 862 |
+
else:
|
| 863 |
+
loss = (loss, num_mask_tokens)
|
| 864 |
+
|
| 865 |
+
return MinistralDiffOutputWithPast(
|
| 866 |
+
loss=loss if not is_teacher else logits,
|
| 867 |
+
logits=logits,
|
| 868 |
+
causal_logits=causal_logits,
|
| 869 |
+
past_key_values=enc_out.past_key_values,
|
| 870 |
+
hidden_states=None,
|
| 871 |
+
attentions=None,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None, max_thinking_tokens=None, end_think_token_id=None):
|
| 876 |
+
if eos_token_id is None:
|
| 877 |
+
eos_token_id = getattr(self.config, 'eos_token_id', None)
|
| 878 |
+
|
| 879 |
+
out_ids, nfe = generate_with_prefix_cache_block_diff(
|
| 880 |
+
model=self,
|
| 881 |
+
prompt=prompt_ids,
|
| 882 |
+
gen_length=max_new_tokens,
|
| 883 |
+
steps=steps,
|
| 884 |
+
block_length=block_length,
|
| 885 |
+
remasking="low_confidence",
|
| 886 |
+
temperature=temperature,
|
| 887 |
+
mask_id=self.mask_token_id,
|
| 888 |
+
threshold=threshold,
|
| 889 |
+
shift_logits=shift_logits,
|
| 890 |
+
neg_entropy=False,
|
| 891 |
+
causal_context=causal_context,
|
| 892 |
+
eos_token_id=eos_token_id,
|
| 893 |
+
max_thinking_tokens=max_thinking_tokens,
|
| 894 |
+
end_think_token_id=end_think_token_id,
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
return out_ids, nfe
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
@torch.no_grad()
|
| 901 |
+
def sbd_inference_diffusion_quadratic(
|
| 902 |
+
self,
|
| 903 |
+
clean_input_ids: Optional[torch.Tensor],
|
| 904 |
+
draft_input_ids: torch.Tensor,
|
| 905 |
+
block_length: int,
|
| 906 |
+
draft_only: bool = False,
|
| 907 |
+
past_key_values: Optional[Cache] = None,
|
| 908 |
+
use_cache: bool = False,
|
| 909 |
+
):
|
| 910 |
+
enc_config = self.encoder.config
|
| 911 |
+
enc_config.use_sbd_objective = True
|
| 912 |
+
enc_config.block_length = block_length
|
| 913 |
+
|
| 914 |
+
if draft_only:
|
| 915 |
+
assert clean_input_ids is not None
|
| 916 |
+
|
| 917 |
+
if use_cache and past_key_values is None:
|
| 918 |
+
past_key_values = DynamicCache()
|
| 919 |
+
|
| 920 |
+
enc_config.self_spec_inference_mode = "default"
|
| 921 |
+
input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
|
| 922 |
+
outputs = self.encoder(
|
| 923 |
+
input_ids=input_ids,
|
| 924 |
+
position_ids=None,
|
| 925 |
+
past_key_values=past_key_values,
|
| 926 |
+
use_cache=use_cache,
|
| 927 |
+
is_training=False,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
hidden_states = outputs.last_hidden_state
|
| 931 |
+
logits = self.diffusion_head(hidden_states)
|
| 932 |
+
|
| 933 |
+
past_key_values = getattr(outputs, "past_key_values", None)
|
| 934 |
+
if use_cache and past_key_values is not None:
|
| 935 |
+
_crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
|
| 936 |
+
|
| 937 |
+
return logits, past_key_values
|
| 938 |
+
else:
|
| 939 |
+
enc_config.self_spec_inference_mode = "quadratic"
|
| 940 |
+
|
| 941 |
+
draft_len = block_length * (block_length + 1)
|
| 942 |
+
draft_input_ids = torch.cat(
|
| 943 |
+
[
|
| 944 |
+
draft_input_ids.view(-1, block_length, 1),
|
| 945 |
+
torch.full(
|
| 946 |
+
(draft_input_ids.shape[0], block_length, block_length),
|
| 947 |
+
fill_value=self.config.mask_token_id,
|
| 948 |
+
device=draft_input_ids.device,
|
| 949 |
+
),
|
| 950 |
+
],
|
| 951 |
+
dim=-1,
|
| 952 |
+
).view(-1, draft_len)
|
| 953 |
+
|
| 954 |
+
if use_cache:
|
| 955 |
+
assert past_key_values is not None, (
|
| 956 |
+
"Past key values should be provided when using cache, e.g. run draft_only=True first."
|
| 957 |
+
)
|
| 958 |
+
assert clean_input_ids is None, (
|
| 959 |
+
"Clean input ids should already be in cache, thus none should be provided."
|
| 960 |
+
)
|
| 961 |
+
clean_len = past_key_values.get_seq_length()
|
| 962 |
+
input_ids = draft_input_ids
|
| 963 |
+
else:
|
| 964 |
+
clean_len = clean_input_ids.shape[1]
|
| 965 |
+
input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
|
| 966 |
+
|
| 967 |
+
per_block_position_ids = torch.arange(
|
| 968 |
+
clean_len, clean_len + block_length + 1, device=draft_input_ids.device
|
| 969 |
+
)[None,].repeat(block_length, 1)
|
| 970 |
+
per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
|
| 971 |
+
|
| 972 |
+
if use_cache:
|
| 973 |
+
position_ids = per_block_position_ids.view(-1)[None,]
|
| 974 |
+
else:
|
| 975 |
+
clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
|
| 976 |
+
position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
|
| 977 |
+
|
| 978 |
+
outputs = self.encoder(
|
| 979 |
+
input_ids=input_ids,
|
| 980 |
+
position_ids=position_ids,
|
| 981 |
+
past_key_values=past_key_values,
|
| 982 |
+
use_cache=use_cache,
|
| 983 |
+
is_training=False,
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
hidden_states = outputs.last_hidden_state
|
| 987 |
+
logits = self.diffusion_head(hidden_states)
|
| 988 |
+
past_key_values = getattr(outputs, "past_key_values", None)
|
| 989 |
+
|
| 990 |
+
if use_cache and past_key_values is not None:
|
| 991 |
+
_extract_draft_kv_cache(past_key_values, clean_len, block_length)
|
| 992 |
+
|
| 993 |
+
return logits, past_key_values
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
@torch.no_grad()
|
| 997 |
+
def ar_generate(
|
| 998 |
+
self,
|
| 999 |
+
prompt_ids: torch.Tensor,
|
| 1000 |
+
max_new_tokens: int = 128,
|
| 1001 |
+
temperature: float = 0.0,
|
| 1002 |
+
eos_token_id: Optional[int] = None,
|
| 1003 |
+
max_thinking_tokens: Optional[int] = None,
|
| 1004 |
+
end_think_token_id: Optional[int] = None,
|
| 1005 |
+
) -> tuple:
|
| 1006 |
+
"""Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
|
| 1007 |
+
|
| 1008 |
+
Bypasses MinistralDiffEncoderModel.forward() to avoid diffusion-specific
|
| 1009 |
+
code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
|
| 1010 |
+
position_ids, and use_cache so the KV cache and causal masking behave
|
| 1011 |
+
identically to MistralForCausalLM / vLLM.
|
| 1012 |
+
|
| 1013 |
+
Returns:
|
| 1014 |
+
(output_ids, nfe) where output_ids includes the prompt.
|
| 1015 |
+
"""
|
| 1016 |
+
for layer in self.encoder.layers:
|
| 1017 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 1018 |
+
layer.self_attn.diffusion_lm = False
|
| 1019 |
+
|
| 1020 |
+
if eos_token_id is None:
|
| 1021 |
+
eos_token_id = getattr(self.config, 'eos_token_id', None)
|
| 1022 |
+
|
| 1023 |
+
device = prompt_ids.device
|
| 1024 |
+
batch_size, prompt_len = prompt_ids.shape
|
| 1025 |
+
|
| 1026 |
+
past_key_values = DynamicCache()
|
| 1027 |
+
cache_position = torch.arange(prompt_len, device=device)
|
| 1028 |
+
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
| 1029 |
+
|
| 1030 |
+
enc_out = self.encoder(
|
| 1031 |
+
input_ids=prompt_ids,
|
| 1032 |
+
position_ids=position_ids,
|
| 1033 |
+
past_key_values=past_key_values,
|
| 1034 |
+
use_cache=True,
|
| 1035 |
+
cache_position=cache_position,
|
| 1036 |
+
)
|
| 1037 |
+
past_key_values = enc_out.past_key_values
|
| 1038 |
+
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 1039 |
+
|
| 1040 |
+
generated_tokens = []
|
| 1041 |
+
nfe = 0
|
| 1042 |
+
|
| 1043 |
+
for step in range(max_new_tokens):
|
| 1044 |
+
nfe += 1
|
| 1045 |
+
|
| 1046 |
+
if temperature > 0:
|
| 1047 |
+
probs = torch.softmax(next_logit / temperature, dim=-1)
|
| 1048 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 1049 |
+
else:
|
| 1050 |
+
next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
|
| 1051 |
+
|
| 1052 |
+
# ---- thinking budget enforcement ----
|
| 1053 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 1054 |
+
if step >= max_thinking_tokens:
|
| 1055 |
+
if generated_tokens:
|
| 1056 |
+
gen_tensor = torch.cat(generated_tokens, dim=1)
|
| 1057 |
+
has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
|
| 1058 |
+
else:
|
| 1059 |
+
has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 1060 |
+
for b in range(batch_size):
|
| 1061 |
+
if not has_end_think[b]:
|
| 1062 |
+
next_token[b] = end_think_token_id
|
| 1063 |
+
|
| 1064 |
+
generated_tokens.append(next_token)
|
| 1065 |
+
|
| 1066 |
+
if eos_token_id is not None and (next_token == eos_token_id).all():
|
| 1067 |
+
break
|
| 1068 |
+
|
| 1069 |
+
if step < max_new_tokens - 1:
|
| 1070 |
+
cur_pos = prompt_len + step
|
| 1071 |
+
step_cache_pos = torch.tensor([cur_pos], device=device)
|
| 1072 |
+
step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
|
| 1073 |
+
|
| 1074 |
+
enc_out = self.encoder(
|
| 1075 |
+
input_ids=next_token,
|
| 1076 |
+
position_ids=step_pos_ids,
|
| 1077 |
+
past_key_values=past_key_values,
|
| 1078 |
+
use_cache=True,
|
| 1079 |
+
cache_position=step_cache_pos,
|
| 1080 |
+
)
|
| 1081 |
+
past_key_values = enc_out.past_key_values
|
| 1082 |
+
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 1083 |
+
|
| 1084 |
+
all_generated = torch.cat(generated_tokens, dim=1)
|
| 1085 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 1086 |
+
return output_ids, nfe
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
@torch.no_grad()
|
| 1090 |
+
def self_spec_generate(
|
| 1091 |
+
self,
|
| 1092 |
+
prompt_ids: torch.Tensor,
|
| 1093 |
+
max_new_tokens: int = 128,
|
| 1094 |
+
steps: int = 128,
|
| 1095 |
+
block_length: int = 16,
|
| 1096 |
+
ar_mix_weight: Optional[float] = None,
|
| 1097 |
+
temperature: float = 0.0,
|
| 1098 |
+
mask_token_id: Optional[int] = None,
|
| 1099 |
+
eos_token_id: Optional[int] = None,
|
| 1100 |
+
max_thinking_tokens: Optional[int] = None,
|
| 1101 |
+
end_think_token_id: Optional[int] = None,
|
| 1102 |
+
):
|
| 1103 |
+
self.config.use_sbd_objective = True
|
| 1104 |
+
self.config.dlm_paradigm = "sbd"
|
| 1105 |
+
|
| 1106 |
+
if prompt_ids.shape[0] != 1:
|
| 1107 |
+
raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
|
| 1108 |
+
|
| 1109 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 1110 |
+
if eos_token_id is None:
|
| 1111 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 1112 |
+
|
| 1113 |
+
x = torch.full(
|
| 1114 |
+
(1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
|
| 1115 |
+
token_mask_id,
|
| 1116 |
+
dtype=torch.long,
|
| 1117 |
+
device=prompt_ids.device,
|
| 1118 |
+
)
|
| 1119 |
+
x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
|
| 1120 |
+
|
| 1121 |
+
if max_new_tokens % block_length != 0:
|
| 1122 |
+
raise ValueError("max_new_tokens must be divisible by block_length")
|
| 1123 |
+
num_blocks = max_new_tokens // block_length
|
| 1124 |
+
if steps % num_blocks != 0:
|
| 1125 |
+
raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
|
| 1126 |
+
|
| 1127 |
+
prompt_len = prompt_ids.shape[1]
|
| 1128 |
+
nfe = 0
|
| 1129 |
+
nfe += 1
|
| 1130 |
+
logits, past_key_values = self.sbd_inference_diffusion_quadratic(
|
| 1131 |
+
clean_input_ids=x[:, :prompt_len],
|
| 1132 |
+
draft_input_ids=x[:, prompt_len : prompt_len + block_length],
|
| 1133 |
+
block_length=block_length,
|
| 1134 |
+
draft_only=True,
|
| 1135 |
+
use_cache=True,
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
|
| 1139 |
+
logits_proposal[:, 1] = logits_proposal[:, 0]
|
| 1140 |
+
logits_proposal = logits_proposal[:, 1:]
|
| 1141 |
+
x0_proposal = torch.argmax(logits_proposal, dim=-1)
|
| 1142 |
+
x[:, prompt_len : prompt_len + block_length] = x0_proposal
|
| 1143 |
+
|
| 1144 |
+
total_accept_token = 0
|
| 1145 |
+
while True:
|
| 1146 |
+
nfe += 1
|
| 1147 |
+
block_start = prompt_len + total_accept_token
|
| 1148 |
+
block_end = block_start + block_length
|
| 1149 |
+
draft_input_ids = x[:, block_start:block_end]
|
| 1150 |
+
|
| 1151 |
+
logits, past_key_values = self.sbd_inference_diffusion_quadratic(
|
| 1152 |
+
clean_input_ids=None,
|
| 1153 |
+
draft_input_ids=draft_input_ids,
|
| 1154 |
+
block_length=block_length,
|
| 1155 |
+
draft_only=False,
|
| 1156 |
+
past_key_values=past_key_values,
|
| 1157 |
+
use_cache=True,
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
|
| 1161 |
+
if ar_mix_weight is None:
|
| 1162 |
+
useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
|
| 1163 |
+
else:
|
| 1164 |
+
if not (0.0 <= ar_mix_weight <= 1.0):
|
| 1165 |
+
raise ValueError("ar_mix_weight must be between 0 and 1")
|
| 1166 |
+
mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
|
| 1167 |
+
useful_token_logits[:, :, 0] = mix_logits
|
| 1168 |
+
useful_token_logits[:, :, 1] = mix_logits
|
| 1169 |
+
|
| 1170 |
+
if temperature > 0:
|
| 1171 |
+
useful_token_logits = useful_token_logits / temperature
|
| 1172 |
+
|
| 1173 |
+
useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
|
| 1174 |
+
new_draft_input_ids = useful_token_pred[:, 0, 1:]
|
| 1175 |
+
accept_cnt = 1
|
| 1176 |
+
|
| 1177 |
+
while accept_cnt < block_length:
|
| 1178 |
+
if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
|
| 1179 |
+
break
|
| 1180 |
+
new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
|
| 1181 |
+
accept_cnt += 1
|
| 1182 |
+
|
| 1183 |
+
x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
|
| 1184 |
+
|
| 1185 |
+
# EoS early stopping: all accepted tokens are finalized left-to-right,
|
| 1186 |
+
# so if any is EoS we can truncate and return immediately.
|
| 1187 |
+
if eos_token_id is not None:
|
| 1188 |
+
accepted = x[0, block_start : block_start + accept_cnt]
|
| 1189 |
+
eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
|
| 1190 |
+
if len(eos_positions) > 0:
|
| 1191 |
+
first_eos_rel = eos_positions[0].item()
|
| 1192 |
+
total_accept_token += first_eos_rel + 1
|
| 1193 |
+
output_end = prompt_len + total_accept_token
|
| 1194 |
+
return x[:, :output_end], nfe
|
| 1195 |
+
|
| 1196 |
+
x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
|
| 1197 |
+
past_key_values.crop(block_start + accept_cnt)
|
| 1198 |
+
|
| 1199 |
+
# ---- thinking budget enforcement ----
|
| 1200 |
+
# Insert end_think as the first token of the next draft block,
|
| 1201 |
+
# shifting all subsequent tokens right by 1 (discarding the last).
|
| 1202 |
+
# The first draft token is always accepted unconditionally, so
|
| 1203 |
+
# end_think is guaranteed to be finalized in the next iteration
|
| 1204 |
+
# without needing to re-encode or touch the KV cache.
|
| 1205 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 1206 |
+
tokens_so_far = total_accept_token + accept_cnt
|
| 1207 |
+
if tokens_so_far > max_thinking_tokens:
|
| 1208 |
+
gen_so_far = x[0, prompt_len : prompt_len + tokens_so_far]
|
| 1209 |
+
has_end_think = (gen_so_far == end_think_token_id).any()
|
| 1210 |
+
if not has_end_think:
|
| 1211 |
+
insert_pos = block_start + accept_cnt
|
| 1212 |
+
x[0, insert_pos + 1:] = x[0, insert_pos:-1].clone()
|
| 1213 |
+
x[0, insert_pos] = end_think_token_id
|
| 1214 |
+
|
| 1215 |
+
total_accept_token += accept_cnt
|
| 1216 |
+
|
| 1217 |
+
if total_accept_token >= max_new_tokens:
|
| 1218 |
+
break
|
| 1219 |
+
|
| 1220 |
+
return x[:, : -(block_length * 2)], nfe
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
@torch.no_grad()
|
| 1224 |
+
def linear_spec_generate(
|
| 1225 |
+
self,
|
| 1226 |
+
prompt_ids: torch.Tensor,
|
| 1227 |
+
max_new_tokens: int = 128,
|
| 1228 |
+
block_length: int = 32,
|
| 1229 |
+
temperature: float = 0.0,
|
| 1230 |
+
mask_token_id: Optional[int] = None,
|
| 1231 |
+
eos_token_id: Optional[int] = None,
|
| 1232 |
+
max_thinking_tokens: Optional[int] = None,
|
| 1233 |
+
end_think_token_id: Optional[int] = None,
|
| 1234 |
+
threshold: float = 0.0,
|
| 1235 |
+
):
|
| 1236 |
+
"""Linear speculative decoding: diffusion draft + AR verification.
|
| 1237 |
+
|
| 1238 |
+
Each step:
|
| 1239 |
+
1. Draft: forward [last_accepted, mask, ...] with bidirectional attention
|
| 1240 |
+
(diffusion_lm=True, use_cache=False). Shift AR logits to get
|
| 1241 |
+
per-position predictions; apply confidence filtering.
|
| 1242 |
+
2. Verify: forward the drafted block with causal attention
|
| 1243 |
+
(diffusion_lm=False, use_cache=True, use_causal_mask=True).
|
| 1244 |
+
Accept consecutive AR-matching tokens plus one bonus token.
|
| 1245 |
+
|
| 1246 |
+
Args:
|
| 1247 |
+
prompt_ids: Input token IDs of shape (1, prompt_len).
|
| 1248 |
+
max_new_tokens: Maximum number of tokens to generate.
|
| 1249 |
+
block_length: Number of tokens per draft/verify block.
|
| 1250 |
+
temperature: Sampling temperature (0 = greedy).
|
| 1251 |
+
mask_token_id: Override for config.mask_token_id.
|
| 1252 |
+
eos_token_id: Override for config.eos_token_id.
|
| 1253 |
+
max_thinking_tokens: Budget for thinking tokens before forcing end_think.
|
| 1254 |
+
end_think_token_id: Token ID inserted when thinking budget is exceeded.
|
| 1255 |
+
threshold: Confidence threshold for accepting draft predictions.
|
| 1256 |
+
|
| 1257 |
+
Returns:
|
| 1258 |
+
(output_ids, nfe): output_ids includes the prompt; nfe is the number
|
| 1259 |
+
of forward evaluations (matching self_spec_generate interface).
|
| 1260 |
+
"""
|
| 1261 |
+
if prompt_ids.shape[0] != 1:
|
| 1262 |
+
raise ValueError("Linear speculative decoding requires batch_size == 1")
|
| 1263 |
+
|
| 1264 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 1265 |
+
if eos_token_id is None:
|
| 1266 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 1267 |
+
|
| 1268 |
+
device = prompt_ids.device
|
| 1269 |
+
prompt_len = prompt_ids.shape[1]
|
| 1270 |
+
dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
|
| 1271 |
+
|
| 1272 |
+
def _set_diffusion_lm(val: bool):
|
| 1273 |
+
for layer in self.encoder.layers:
|
| 1274 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 1275 |
+
layer.self_attn.diffusion_lm = val
|
| 1276 |
+
|
| 1277 |
+
# ===== Prefill (causal) =====
|
| 1278 |
+
_set_diffusion_lm(False)
|
| 1279 |
+
|
| 1280 |
+
enc_out = self.encoder(
|
| 1281 |
+
input_ids=prompt_ids,
|
| 1282 |
+
past_key_values=DynamicCache(),
|
| 1283 |
+
use_cache=True,
|
| 1284 |
+
use_causal_mask=True,
|
| 1285 |
+
)
|
| 1286 |
+
past_key_values = enc_out.past_key_values
|
| 1287 |
+
last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 1288 |
+
nfe = 1
|
| 1289 |
+
|
| 1290 |
+
if temperature > 0:
|
| 1291 |
+
probs = torch.softmax(last_logit / temperature, dim=-1)
|
| 1292 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 1293 |
+
else:
|
| 1294 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 1295 |
+
|
| 1296 |
+
if eos_token_id is not None and next_token.item() == eos_token_id:
|
| 1297 |
+
output_ids = torch.cat([prompt_ids, next_token], dim=1)
|
| 1298 |
+
return output_ids, nfe
|
| 1299 |
+
|
| 1300 |
+
generated = [next_token]
|
| 1301 |
+
total_gen = 1
|
| 1302 |
+
|
| 1303 |
+
# ===== Main loop =====
|
| 1304 |
+
while total_gen < max_new_tokens:
|
| 1305 |
+
cache_len = past_key_values.get_seq_length()
|
| 1306 |
+
|
| 1307 |
+
block = torch.full(
|
| 1308 |
+
(1, block_length), token_mask_id, dtype=torch.long, device=device
|
| 1309 |
+
)
|
| 1310 |
+
block[0, 0] = next_token.item()
|
| 1311 |
+
|
| 1312 |
+
# -------- Draft (bidirectional, don't update cache) --------
|
| 1313 |
+
_set_diffusion_lm(True)
|
| 1314 |
+
while True:
|
| 1315 |
+
is_mask = block == token_mask_id
|
| 1316 |
+
if not is_mask.any():
|
| 1317 |
+
break
|
| 1318 |
+
|
| 1319 |
+
enc_out = self.encoder(
|
| 1320 |
+
input_ids=block,
|
| 1321 |
+
past_key_values=past_key_values,
|
| 1322 |
+
use_cache=False,
|
| 1323 |
+
)
|
| 1324 |
+
nfe += 1
|
| 1325 |
+
|
| 1326 |
+
draft_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1327 |
+
if dream_style:
|
| 1328 |
+
# DREAM: logit[i] predicts position i+1 → shift to self-prediction
|
| 1329 |
+
draft_logits = torch.cat(
|
| 1330 |
+
[draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1
|
| 1331 |
+
)
|
| 1332 |
+
# LLaDA: logit[i] already predicts position i → no shift needed
|
| 1333 |
+
|
| 1334 |
+
if temperature > 0:
|
| 1335 |
+
draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
|
| 1336 |
+
draft_tokens = torch.multinomial(
|
| 1337 |
+
draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
|
| 1338 |
+
).view(1, block_length)
|
| 1339 |
+
else:
|
| 1340 |
+
draft_tokens = draft_logits.argmax(dim=-1)
|
| 1341 |
+
draft_probs = torch.softmax(draft_logits, dim=-1)
|
| 1342 |
+
|
| 1343 |
+
if threshold > 0:
|
| 1344 |
+
draft_conf = torch.gather(
|
| 1345 |
+
draft_probs, -1, draft_tokens.unsqueeze(-1)
|
| 1346 |
+
).squeeze(-1)
|
| 1347 |
+
draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
|
| 1348 |
+
unmask = draft_conf >= threshold
|
| 1349 |
+
|
| 1350 |
+
# Ensure each iteration makes progress even when every masked
|
| 1351 |
+
# position falls below the confidence threshold.
|
| 1352 |
+
if not unmask.any():
|
| 1353 |
+
best_idx = draft_conf.view(-1).argmax()
|
| 1354 |
+
unmask = torch.zeros_like(is_mask, dtype=torch.bool)
|
| 1355 |
+
unmask.view(-1)[best_idx] = True
|
| 1356 |
+
|
| 1357 |
+
block[unmask] = draft_tokens[unmask]
|
| 1358 |
+
else:
|
| 1359 |
+
block[is_mask] = draft_tokens[is_mask]
|
| 1360 |
+
break
|
| 1361 |
+
|
| 1362 |
+
# -------- Verify (causal, update cache) --------
|
| 1363 |
+
_set_diffusion_lm(False)
|
| 1364 |
+
enc_out = self.encoder(
|
| 1365 |
+
input_ids=block,
|
| 1366 |
+
past_key_values=past_key_values,
|
| 1367 |
+
use_cache=True,
|
| 1368 |
+
use_causal_mask=True,
|
| 1369 |
+
)
|
| 1370 |
+
past_key_values = enc_out.past_key_values
|
| 1371 |
+
nfe += 1
|
| 1372 |
+
|
| 1373 |
+
verify_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1374 |
+
if temperature > 0:
|
| 1375 |
+
verify_probs = torch.softmax(verify_logits / temperature, dim=-1)
|
| 1376 |
+
ar_tokens = torch.multinomial(
|
| 1377 |
+
verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
|
| 1378 |
+
).view(1, block_length)
|
| 1379 |
+
else:
|
| 1380 |
+
ar_tokens = verify_logits.argmax(dim=-1)
|
| 1381 |
+
|
| 1382 |
+
accepted = 0
|
| 1383 |
+
for i in range(block_length - 1):
|
| 1384 |
+
if ar_tokens[0, i].item() == block[0, i + 1].item():
|
| 1385 |
+
accepted += 1
|
| 1386 |
+
else:
|
| 1387 |
+
break
|
| 1388 |
+
accepted += 1 # bonus token from AR verification
|
| 1389 |
+
|
| 1390 |
+
accepted_toks = ar_tokens[:, :accepted]
|
| 1391 |
+
generated.append(accepted_toks)
|
| 1392 |
+
total_gen += accepted
|
| 1393 |
+
|
| 1394 |
+
_crop_dynamic_cache(past_key_values, cache_len + accepted)
|
| 1395 |
+
|
| 1396 |
+
next_token = ar_tokens[:, accepted - 1 : accepted]
|
| 1397 |
+
|
| 1398 |
+
# -------- EOS check --------
|
| 1399 |
+
if eos_token_id is not None:
|
| 1400 |
+
eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
| 1401 |
+
if len(eos_pos) > 0:
|
| 1402 |
+
first_eos = eos_pos[0].item()
|
| 1403 |
+
generated[-1] = accepted_toks[:, : first_eos + 1]
|
| 1404 |
+
total_gen = total_gen - accepted + first_eos + 1
|
| 1405 |
+
break
|
| 1406 |
+
|
| 1407 |
+
# -------- Thinking budget enforcement --------
|
| 1408 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 1409 |
+
if total_gen > max_thinking_tokens:
|
| 1410 |
+
all_gen = torch.cat(generated, dim=1)
|
| 1411 |
+
if not (all_gen == end_think_token_id).any():
|
| 1412 |
+
next_token = torch.tensor(
|
| 1413 |
+
[[end_think_token_id]], device=device
|
| 1414 |
+
)
|
| 1415 |
+
|
| 1416 |
+
if total_gen >= max_new_tokens:
|
| 1417 |
+
break
|
| 1418 |
+
|
| 1419 |
+
all_generated = torch.cat(generated, dim=1)
|
| 1420 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 1421 |
+
|
| 1422 |
+
return output_ids, nfe
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
@torch.no_grad()
|
| 1426 |
+
def linear_spec_generate_mp(
|
| 1427 |
+
self,
|
| 1428 |
+
prompt_ids: torch.Tensor,
|
| 1429 |
+
max_new_tokens: int = 512,
|
| 1430 |
+
block_length: int = 32,
|
| 1431 |
+
temperature: float = 0.0,
|
| 1432 |
+
mask_token_id: Optional[int] = None,
|
| 1433 |
+
eos_token_id: Optional[int] = None,
|
| 1434 |
+
max_paths: int = 16,
|
| 1435 |
+
uncertain_threshold: float = 0.7,
|
| 1436 |
+
top_k_candidates: int = 2,
|
| 1437 |
+
threshold: float = 0.0,
|
| 1438 |
+
max_thinking_tokens: Optional[int] = None,
|
| 1439 |
+
end_think_token_id: Optional[int] = None,
|
| 1440 |
+
):
|
| 1441 |
+
"""Linear speculative decoding with multi-path tree verification.
|
| 1442 |
+
|
| 1443 |
+
Self-contained method — no external file dependencies beyond the model itself.
|
| 1444 |
+
|
| 1445 |
+
Each iteration costs 2 NFE (1 draft + 1 verify):
|
| 1446 |
+
1. Draft: single-step bidirectional diffusion fills a block of masks.
|
| 1447 |
+
2. Verify: tree-structured AR verification with multiple candidate paths.
|
| 1448 |
+
|
| 1449 |
+
Multi-path verification identifies low-confidence draft positions and
|
| 1450 |
+
explores top-k alternative tokens. All candidate paths share a trie
|
| 1451 |
+
prefix and are verified in one forward pass via a 4D tree-ancestry
|
| 1452 |
+
attention mask (~40 tokens), picking the path with the longest
|
| 1453 |
+
accepted prefix.
|
| 1454 |
+
|
| 1455 |
+
Benchmark results (NeMo Skills prompt, enable_thinking=False):
|
| 1456 |
+
GSM8K bl=32: +17.1% UW-TPF vs vanilla (acc 93.9%)
|
| 1457 |
+
MBPP bl=64: +17.8% UW-TPF vs vanilla (pass@1 78.2%)
|
| 1458 |
+
|
| 1459 |
+
Args:
|
| 1460 |
+
prompt_ids: (1, prompt_len) input token IDs.
|
| 1461 |
+
max_new_tokens: Maximum tokens to generate.
|
| 1462 |
+
block_length: Draft block size. Use 32 for math, 64 for code.
|
| 1463 |
+
temperature: Sampling temperature (0.0 = greedy).
|
| 1464 |
+
eos_token_id: Stop token ID.
|
| 1465 |
+
max_paths: Tree verification budget. 16 = up to 4 uncertain
|
| 1466 |
+
positions x 2 candidates each.
|
| 1467 |
+
uncertain_threshold: Confidence below which a position is
|
| 1468 |
+
considered uncertain and expanded with alternatives.
|
| 1469 |
+
top_k_candidates: Number of alternative tokens to try at each
|
| 1470 |
+
uncertain position.
|
| 1471 |
+
|
| 1472 |
+
Returns:
|
| 1473 |
+
output_ids: (1, prompt_len + generated_len) full sequence.
|
| 1474 |
+
nfe: Total number of forward evaluations.
|
| 1475 |
+
"""
|
| 1476 |
+
from itertools import product as _product
|
| 1477 |
+
|
| 1478 |
+
if prompt_ids.shape[0] != 1:
|
| 1479 |
+
raise ValueError("Requires batch_size == 1")
|
| 1480 |
+
|
| 1481 |
+
device = prompt_ids.device
|
| 1482 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 1483 |
+
if eos_token_id is None:
|
| 1484 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 1485 |
+
|
| 1486 |
+
def _set_dlm(val: bool):
|
| 1487 |
+
for layer in self.encoder.layers:
|
| 1488 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 1489 |
+
layer.self_attn.diffusion_lm = val
|
| 1490 |
+
|
| 1491 |
+
def _crop_cache(kv, length):
|
| 1492 |
+
# transformers 4.55 exposes .key_cache/.value_cache lists; 5.0 moved them under .layers[i].keys/.values.
|
| 1493 |
+
for li in range(len(kv)):
|
| 1494 |
+
if hasattr(kv, 'layers'):
|
| 1495 |
+
layer = kv.layers[li]
|
| 1496 |
+
layer.keys = layer.keys[:, :, :length]
|
| 1497 |
+
layer.values = layer.values[:, :, :length]
|
| 1498 |
+
else:
|
| 1499 |
+
kv.key_cache[li] = kv.key_cache[li][:, :, :length]
|
| 1500 |
+
kv.value_cache[li] = kv.value_cache[li][:, :, :length]
|
| 1501 |
+
kv._seen_tokens = length
|
| 1502 |
+
|
| 1503 |
+
# ----- tree verify helpers (inlined) -----
|
| 1504 |
+
|
| 1505 |
+
def _mp_verify(block, draft_probs, draft_conf, past_kv, cache_len):
|
| 1506 |
+
"""Multi-path verify via batch-stacking (flash-attention compatible).
|
| 1507 |
+
|
| 1508 |
+
Unlike tree attention (4D mask), batch-stacking expands the KV cache
|
| 1509 |
+
batch dimension and runs all candidate paths as separate batch entries.
|
| 1510 |
+
This keeps flash attention + GQA enabled, avoiding OOM from the 4D
|
| 1511 |
+
mask path which disables both.
|
| 1512 |
+
|
| 1513 |
+
Returns (accepted_toks, n_accepted, past_kv, next_tok) or None.
|
| 1514 |
+
"""
|
| 1515 |
+
bl = block.shape[1]
|
| 1516 |
+
|
| 1517 |
+
# Identify uncertain positions
|
| 1518 |
+
is_filled = block[0] != token_mask_id
|
| 1519 |
+
pos_conf = torch.zeros(bl, device=device)
|
| 1520 |
+
pos_conf[0] = float('inf')
|
| 1521 |
+
for p in range(1, bl):
|
| 1522 |
+
if is_filled[p]:
|
| 1523 |
+
c = draft_conf[0, p].item()
|
| 1524 |
+
pos_conf[p] = c if c != float('-inf') else float('inf')
|
| 1525 |
+
else:
|
| 1526 |
+
pos_conf[p] = float('-inf')
|
| 1527 |
+
|
| 1528 |
+
unc_mask = (pos_conf < uncertain_threshold) & (pos_conf > float('-inf'))
|
| 1529 |
+
unc_pos = unc_mask.nonzero(as_tuple=True)[0].tolist()
|
| 1530 |
+
if not unc_pos:
|
| 1531 |
+
return None
|
| 1532 |
+
|
| 1533 |
+
import math as _math
|
| 1534 |
+
max_unc = min(len(unc_pos), max(1, int(_math.log2(max_paths))))
|
| 1535 |
+
unc_pos = sorted(unc_pos)[:max_unc]
|
| 1536 |
+
|
| 1537 |
+
# Build candidate blocks
|
| 1538 |
+
topk_at = {}
|
| 1539 |
+
for p in unc_pos:
|
| 1540 |
+
_, ids = draft_probs[0, p].topk(top_k_candidates)
|
| 1541 |
+
topk_at[p] = ids.tolist()
|
| 1542 |
+
|
| 1543 |
+
combos = list(_product(*(topk_at[p] for p in sorted(topk_at))))[:max_paths]
|
| 1544 |
+
num_paths = len(combos)
|
| 1545 |
+
if num_paths <= 1:
|
| 1546 |
+
return None
|
| 1547 |
+
|
| 1548 |
+
candidate_blocks = block.expand(num_paths, -1).clone()
|
| 1549 |
+
pos_list = sorted(topk_at.keys())
|
| 1550 |
+
for pi, combo in enumerate(combos):
|
| 1551 |
+
for ci, p in enumerate(pos_list):
|
| 1552 |
+
candidate_blocks[pi, p] = combo[ci]
|
| 1553 |
+
|
| 1554 |
+
# Expand KV cache batch dimension (shared, no copy)
|
| 1555 |
+
for li in range(len(past_kv)):
|
| 1556 |
+
if hasattr(past_kv, 'layers'):
|
| 1557 |
+
layer = past_kv.layers[li]
|
| 1558 |
+
layer.keys = layer.keys.expand(num_paths, -1, -1, -1)
|
| 1559 |
+
layer.values = layer.values.expand(num_paths, -1, -1, -1)
|
| 1560 |
+
else:
|
| 1561 |
+
past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
|
| 1562 |
+
past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
|
| 1563 |
+
|
| 1564 |
+
# Batched causal verify — uses flash attention + GQA
|
| 1565 |
+
_set_dlm(False)
|
| 1566 |
+
enc_out = self.encoder(
|
| 1567 |
+
input_ids=candidate_blocks,
|
| 1568 |
+
past_key_values=past_kv,
|
| 1569 |
+
use_cache=True,
|
| 1570 |
+
use_causal_mask=True,
|
| 1571 |
+
)
|
| 1572 |
+
past_kv = enc_out.past_key_values
|
| 1573 |
+
vlogits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1574 |
+
|
| 1575 |
+
if temperature > 0:
|
| 1576 |
+
vp = torch.softmax(vlogits / temperature, dim=-1)
|
| 1577 |
+
ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(num_paths, bl)
|
| 1578 |
+
else:
|
| 1579 |
+
ar_tokens = vlogits.argmax(dim=-1)
|
| 1580 |
+
|
| 1581 |
+
# Find best path (longest accepted prefix)
|
| 1582 |
+
best_acc, best_pidx = 0, 0
|
| 1583 |
+
for pi in range(num_paths):
|
| 1584 |
+
acc = 0
|
| 1585 |
+
for i in range(bl - 1):
|
| 1586 |
+
if ar_tokens[pi, i].item() == candidate_blocks[pi, i + 1].item():
|
| 1587 |
+
acc += 1
|
| 1588 |
+
else:
|
| 1589 |
+
break
|
| 1590 |
+
acc += 1
|
| 1591 |
+
if acc > best_acc:
|
| 1592 |
+
best_acc, best_pidx = acc, pi
|
| 1593 |
+
|
| 1594 |
+
accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
|
| 1595 |
+
|
| 1596 |
+
# Extract winning path's KV cache slice
|
| 1597 |
+
for li in range(len(past_kv)):
|
| 1598 |
+
if hasattr(past_kv, 'layers'):
|
| 1599 |
+
layer = past_kv.layers[li]
|
| 1600 |
+
layer.keys = layer.keys[best_pidx:best_pidx+1].contiguous()
|
| 1601 |
+
layer.values = layer.values[best_pidx:best_pidx+1].contiguous()
|
| 1602 |
+
else:
|
| 1603 |
+
past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
|
| 1604 |
+
past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
|
| 1605 |
+
_crop_cache(past_kv, cache_len + best_acc)
|
| 1606 |
+
|
| 1607 |
+
return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
|
| 1608 |
+
|
| 1609 |
+
# ── Prefill (causal) ──
|
| 1610 |
+
_set_dlm(False)
|
| 1611 |
+
enc_out = self.encoder(
|
| 1612 |
+
input_ids=prompt_ids, past_key_values=DynamicCache(),
|
| 1613 |
+
use_cache=True, use_causal_mask=True,
|
| 1614 |
+
)
|
| 1615 |
+
past_key_values = enc_out.past_key_values
|
| 1616 |
+
last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 1617 |
+
nfe = 1
|
| 1618 |
+
|
| 1619 |
+
if temperature > 0:
|
| 1620 |
+
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), 1)
|
| 1621 |
+
else:
|
| 1622 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 1623 |
+
|
| 1624 |
+
if eos_token_id is not None and next_token.item() == eos_token_id:
|
| 1625 |
+
return torch.cat([prompt_ids, next_token], dim=1), nfe
|
| 1626 |
+
|
| 1627 |
+
generated = [next_token]
|
| 1628 |
+
total_gen = 1
|
| 1629 |
+
|
| 1630 |
+
# ── Main draft-verify loop ──
|
| 1631 |
+
while total_gen < max_new_tokens:
|
| 1632 |
+
cache_len = past_key_values.get_seq_length()
|
| 1633 |
+
|
| 1634 |
+
block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
|
| 1635 |
+
block[0, 0] = next_token.item()
|
| 1636 |
+
|
| 1637 |
+
# Draft: single-step bidirectional diffusion (1 NFE)
|
| 1638 |
+
_set_dlm(True)
|
| 1639 |
+
enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
|
| 1640 |
+
nfe += 1
|
| 1641 |
+
|
| 1642 |
+
draft_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1643 |
+
if temperature > 0:
|
| 1644 |
+
draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
|
| 1645 |
+
draft_tokens = torch.multinomial(
|
| 1646 |
+
draft_probs.view(-1, draft_probs.shape[-1]), 1
|
| 1647 |
+
).view(1, block_length)
|
| 1648 |
+
else:
|
| 1649 |
+
draft_tokens = draft_logits.argmax(dim=-1)
|
| 1650 |
+
draft_probs = torch.softmax(draft_logits, dim=-1)
|
| 1651 |
+
|
| 1652 |
+
draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
|
| 1653 |
+
is_mask = block == token_mask_id
|
| 1654 |
+
draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
|
| 1655 |
+
block[is_mask] = draft_tokens[is_mask]
|
| 1656 |
+
|
| 1657 |
+
# Verify: multi-path batch-stacking (1 NFE, flash-attention compatible)
|
| 1658 |
+
result = _mp_verify(block, draft_probs, draft_conf, past_key_values, cache_len)
|
| 1659 |
+
|
| 1660 |
+
if result is not None:
|
| 1661 |
+
accepted_toks, accepted, past_key_values, next_token = result
|
| 1662 |
+
nfe += 1
|
| 1663 |
+
else:
|
| 1664 |
+
# No uncertain positions — single-path causal verify
|
| 1665 |
+
_set_dlm(False)
|
| 1666 |
+
enc_out = self.encoder(
|
| 1667 |
+
input_ids=block, past_key_values=past_key_values,
|
| 1668 |
+
use_cache=True, use_causal_mask=True,
|
| 1669 |
+
)
|
| 1670 |
+
past_key_values = enc_out.past_key_values
|
| 1671 |
+
nfe += 1
|
| 1672 |
+
|
| 1673 |
+
vlogits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1674 |
+
if temperature > 0:
|
| 1675 |
+
vp = torch.softmax(vlogits / temperature, dim=-1)
|
| 1676 |
+
ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(1, block_length)
|
| 1677 |
+
else:
|
| 1678 |
+
ar_tokens = vlogits.argmax(dim=-1)
|
| 1679 |
+
|
| 1680 |
+
accepted = 0
|
| 1681 |
+
for i in range(block_length - 1):
|
| 1682 |
+
if ar_tokens[0, i].item() == block[0, i + 1].item():
|
| 1683 |
+
accepted += 1
|
| 1684 |
+
else:
|
| 1685 |
+
break
|
| 1686 |
+
accepted += 1
|
| 1687 |
+
|
| 1688 |
+
accepted_toks = ar_tokens[:, :accepted]
|
| 1689 |
+
_crop_cache(past_key_values, cache_len + accepted)
|
| 1690 |
+
next_token = ar_tokens[:, accepted - 1 : accepted]
|
| 1691 |
+
|
| 1692 |
+
generated.append(accepted_toks)
|
| 1693 |
+
total_gen += accepted
|
| 1694 |
+
|
| 1695 |
+
if eos_token_id is not None:
|
| 1696 |
+
eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
| 1697 |
+
if len(eos_pos) > 0:
|
| 1698 |
+
first_eos = eos_pos[0].item()
|
| 1699 |
+
generated[-1] = accepted_toks[:, :first_eos + 1]
|
| 1700 |
+
total_gen = total_gen - accepted + first_eos + 1
|
| 1701 |
+
break
|
| 1702 |
+
|
| 1703 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 1704 |
+
if total_gen > max_thinking_tokens:
|
| 1705 |
+
all_gen = torch.cat(generated, dim=1)
|
| 1706 |
+
if not (all_gen == end_think_token_id).any():
|
| 1707 |
+
next_token = torch.tensor(
|
| 1708 |
+
[[end_think_token_id]], device=device
|
| 1709 |
+
)
|
| 1710 |
+
|
| 1711 |
+
if total_gen >= max_new_tokens:
|
| 1712 |
+
break
|
| 1713 |
+
|
| 1714 |
+
all_generated = torch.cat(generated, dim=1)
|
| 1715 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 1716 |
+
return output_ids, nfe
|
| 1717 |
+
|
| 1718 |
+
|
| 1719 |
+
@torch.no_grad()
|
| 1720 |
+
def linear_spec_generate_lora(
|
| 1721 |
+
self,
|
| 1722 |
+
prompt_ids: torch.Tensor,
|
| 1723 |
+
max_new_tokens: int = 128,
|
| 1724 |
+
block_length: int = 32,
|
| 1725 |
+
temperature: float = 0.0,
|
| 1726 |
+
mask_token_id: Optional[int] = None,
|
| 1727 |
+
eos_token_id: Optional[int] = None,
|
| 1728 |
+
threshold: float = 0.0,
|
| 1729 |
+
rebuild_kv: str = 'none',
|
| 1730 |
+
max_thinking_tokens: Optional[int] = None,
|
| 1731 |
+
end_think_token_id: Optional[int] = None,
|
| 1732 |
+
):
|
| 1733 |
+
"""Linear speculative decoding: diffusion draft + AR verify.
|
| 1734 |
+
LoRA adapter toggling: ON for draft (bidirectional), OFF for verify (causal).
|
| 1735 |
+
Returns (output_ids, nfe).
|
| 1736 |
+
"""
|
| 1737 |
+
if prompt_ids.shape[0] != 1:
|
| 1738 |
+
raise ValueError("linear_spec_generate requires batch_size == 1")
|
| 1739 |
+
|
| 1740 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 1741 |
+
if eos_token_id is None:
|
| 1742 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 1743 |
+
|
| 1744 |
+
device = prompt_ids.device
|
| 1745 |
+
dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
|
| 1746 |
+
|
| 1747 |
+
def _set_diffusion_lm(val: bool):
|
| 1748 |
+
for layer in self.encoder.layers:
|
| 1749 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 1750 |
+
layer.self_attn.diffusion_lm = val
|
| 1751 |
+
|
| 1752 |
+
def _toggle_adapters(model, enable: bool):
|
| 1753 |
+
for module in model.modules():
|
| 1754 |
+
if hasattr(module, '_disable_adapters'):
|
| 1755 |
+
module._disable_adapters = not enable
|
| 1756 |
+
|
| 1757 |
+
# Prefill (causal, LoRA OFF)
|
| 1758 |
+
_set_diffusion_lm(False)
|
| 1759 |
+
_toggle_adapters(self, False)
|
| 1760 |
+
enc_out = self.encoder(
|
| 1761 |
+
input_ids=prompt_ids,
|
| 1762 |
+
past_key_values=DynamicCache(),
|
| 1763 |
+
use_cache=True,
|
| 1764 |
+
use_causal_mask=True,
|
| 1765 |
+
)
|
| 1766 |
+
past_key_values = enc_out.past_key_values
|
| 1767 |
+
last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 1768 |
+
nfe = 1
|
| 1769 |
+
|
| 1770 |
+
if temperature > 0:
|
| 1771 |
+
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 1772 |
+
else:
|
| 1773 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 1774 |
+
|
| 1775 |
+
if eos_token_id is not None and next_token.item() == eos_token_id:
|
| 1776 |
+
return torch.cat([prompt_ids, next_token], dim=1), nfe
|
| 1777 |
+
|
| 1778 |
+
generated = [next_token]
|
| 1779 |
+
total_gen = 1
|
| 1780 |
+
|
| 1781 |
+
while total_gen < max_new_tokens:
|
| 1782 |
+
cache_len = past_key_values.get_seq_length()
|
| 1783 |
+
|
| 1784 |
+
block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
|
| 1785 |
+
block[0, 0] = next_token.item()
|
| 1786 |
+
|
| 1787 |
+
# Draft (bidirectional, LoRA ON)
|
| 1788 |
+
_set_diffusion_lm(True)
|
| 1789 |
+
_toggle_adapters(self, True)
|
| 1790 |
+
enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
|
| 1791 |
+
nfe += 1
|
| 1792 |
+
|
| 1793 |
+
draft_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1794 |
+
if dream_style:
|
| 1795 |
+
draft_logits = torch.cat([draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1)
|
| 1796 |
+
|
| 1797 |
+
if temperature > 0:
|
| 1798 |
+
draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
|
| 1799 |
+
draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1).view(1, block_length)
|
| 1800 |
+
else:
|
| 1801 |
+
draft_tokens = draft_logits.argmax(dim=-1)
|
| 1802 |
+
draft_probs = torch.softmax(draft_logits, dim=-1)
|
| 1803 |
+
|
| 1804 |
+
draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
|
| 1805 |
+
is_mask = block == token_mask_id
|
| 1806 |
+
draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
|
| 1807 |
+
unmask = draft_conf > threshold
|
| 1808 |
+
if unmask.sum() > 0:
|
| 1809 |
+
block[unmask] = draft_tokens[unmask]
|
| 1810 |
+
|
| 1811 |
+
# Verify (causal, LoRA OFF)
|
| 1812 |
+
_set_diffusion_lm(False)
|
| 1813 |
+
_toggle_adapters(self, False)
|
| 1814 |
+
enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=True, use_causal_mask=True)
|
| 1815 |
+
past_key_values = enc_out.past_key_values
|
| 1816 |
+
nfe += 1
|
| 1817 |
+
|
| 1818 |
+
verify_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 1819 |
+
if temperature > 0:
|
| 1820 |
+
ar_tokens = torch.multinomial(torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]), num_samples=1).view(1, block_length)
|
| 1821 |
+
else:
|
| 1822 |
+
ar_tokens = verify_logits.argmax(dim=-1)
|
| 1823 |
+
|
| 1824 |
+
accepted = 0
|
| 1825 |
+
for i in range(block_length - 1):
|
| 1826 |
+
if ar_tokens[0, i].item() == block[0, i + 1].item():
|
| 1827 |
+
accepted += 1
|
| 1828 |
+
else:
|
| 1829 |
+
break
|
| 1830 |
+
accepted += 1 # bonus token
|
| 1831 |
+
|
| 1832 |
+
accepted_toks = ar_tokens[:, :accepted]
|
| 1833 |
+
generated.append(accepted_toks)
|
| 1834 |
+
total_gen += accepted
|
| 1835 |
+
|
| 1836 |
+
_crop_dynamic_cache(past_key_values, cache_len + accepted)
|
| 1837 |
+
next_token = ar_tokens[:, accepted - 1 : accepted]
|
| 1838 |
+
|
| 1839 |
+
# EOS check
|
| 1840 |
+
if eos_token_id is not None:
|
| 1841 |
+
eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
| 1842 |
+
if len(eos_pos) > 0:
|
| 1843 |
+
first_eos = eos_pos[0].item()
|
| 1844 |
+
generated[-1] = accepted_toks[:, : first_eos + 1]
|
| 1845 |
+
total_gen = total_gen - accepted + first_eos + 1
|
| 1846 |
+
break
|
| 1847 |
+
|
| 1848 |
+
# Thinking budget enforcement
|
| 1849 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 1850 |
+
if total_gen > max_thinking_tokens:
|
| 1851 |
+
all_gen = torch.cat(generated, dim=1)
|
| 1852 |
+
if not (all_gen == end_think_token_id).any():
|
| 1853 |
+
next_token = torch.tensor([[end_think_token_id]], device=device)
|
| 1854 |
+
|
| 1855 |
+
if total_gen >= max_new_tokens:
|
| 1856 |
+
break
|
| 1857 |
+
|
| 1858 |
+
all_generated = torch.cat(generated, dim=1)
|
| 1859 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 1860 |
+
return output_ids, nfe
|
modeling_nemotron_labs_diffusion.py
DELETED
|
@@ -1,870 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
from typing import Optional, Tuple
|
| 4 |
-
import numpy as np
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
from torch import nn
|
| 9 |
-
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
|
| 10 |
-
from transformers.utils import ModelOutput
|
| 11 |
-
|
| 12 |
-
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 13 |
-
|
| 14 |
-
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 15 |
-
|
| 16 |
-
from transformers.processing_utils import Unpack
|
| 17 |
-
|
| 18 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 19 |
-
|
| 20 |
-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 21 |
-
|
| 22 |
-
from transformers.generation import GenerationMixin
|
| 23 |
-
|
| 24 |
-
import math
|
| 25 |
-
|
| 26 |
-
from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
|
| 27 |
-
from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
|
| 28 |
-
|
| 29 |
-
__all__ = ["NemotronLabsDiffusionModel", "NemotronLabsDiffusionFlexAttention"]
|
| 30 |
-
|
| 31 |
-
@dataclass
|
| 32 |
-
class NemotronLabsDiffusionOutputWithPast(ModelOutput):
|
| 33 |
-
loss: torch.FloatTensor | None = None
|
| 34 |
-
logits: torch.FloatTensor | None = None
|
| 35 |
-
causal_logits: torch.FloatTensor | None = None
|
| 36 |
-
past_key_values: Cache | None = None
|
| 37 |
-
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
| 38 |
-
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
|
| 42 |
-
def fused_flex_attention(q, k, v, block_mask=None):
|
| 43 |
-
return flex_attention(q, k, v, block_mask=block_mask)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class NemotronLabsDiffusionFlexAttention(Ministral3Attention):
|
| 47 |
-
def __init__(self, *args, **kwargs):
|
| 48 |
-
super().__init__(*args, **kwargs)
|
| 49 |
-
|
| 50 |
-
self.block_size = self.config.block_size
|
| 51 |
-
self.block_diff_mask = None
|
| 52 |
-
|
| 53 |
-
import torch._dynamo.config as dcfg
|
| 54 |
-
dcfg.cache_size_limit = 512
|
| 55 |
-
|
| 56 |
-
def compute_block_mask(self, mode, q_len, block_size=None):
|
| 57 |
-
|
| 58 |
-
def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 59 |
-
x0_flag_q = (q_idx >= n)
|
| 60 |
-
x0_flag_kv = (kv_idx >= n)
|
| 61 |
-
|
| 62 |
-
# Compute block indices
|
| 63 |
-
block_q = torch.where(x0_flag_q == 1,
|
| 64 |
-
(q_idx - n) // block_size,
|
| 65 |
-
q_idx // block_size)
|
| 66 |
-
block_kv = torch.where(x0_flag_kv == 1,
|
| 67 |
-
(kv_idx - n) // block_size,
|
| 68 |
-
kv_idx // block_size)
|
| 69 |
-
|
| 70 |
-
# **1. Block Diagonal Mask (M_BD) **
|
| 71 |
-
block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
|
| 72 |
-
|
| 73 |
-
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 74 |
-
offset_block_causal = (
|
| 75 |
-
(block_q > block_kv)
|
| 76 |
-
& (x0_flag_kv == 1)
|
| 77 |
-
& (x0_flag_q == 0)
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
# **3. Fully Causal Mask (M_BC) **
|
| 81 |
-
fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 82 |
-
|
| 83 |
-
# **4. Combine Masks **
|
| 84 |
-
return block_diagonal | offset_block_causal | fully_causal
|
| 85 |
-
|
| 86 |
-
attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
|
| 87 |
-
|
| 88 |
-
block_mask = create_block_mask(
|
| 89 |
-
attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
return block_mask
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def forward(
|
| 96 |
-
self,
|
| 97 |
-
hidden_states: torch.Tensor,
|
| 98 |
-
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 99 |
-
attention_mask: Optional[torch.Tensor],
|
| 100 |
-
past_key_values: Optional[Cache] = None,
|
| 101 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 102 |
-
is_training: bool = True,
|
| 103 |
-
**kwargs: Unpack[FlashAttentionKwargs],
|
| 104 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 105 |
-
bsz, q_len, _ = hidden_states.size()
|
| 106 |
-
input_shape = hidden_states.shape[:-1]
|
| 107 |
-
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 108 |
-
|
| 109 |
-
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 110 |
-
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 111 |
-
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 112 |
-
|
| 113 |
-
cos, sin = position_embeddings
|
| 114 |
-
|
| 115 |
-
if is_training:
|
| 116 |
-
# Split query and key states in half along sequence length dimension
|
| 117 |
-
q1, q2 = query_states.chunk(2, dim=2)
|
| 118 |
-
k1, k2 = key_states.chunk(2, dim=2)
|
| 119 |
-
|
| 120 |
-
# Apply RoPE independently to each half
|
| 121 |
-
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
| 122 |
-
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
| 123 |
-
|
| 124 |
-
# Recombine the halves
|
| 125 |
-
query_states = torch.cat([q1, q2], dim=2)
|
| 126 |
-
key_states = torch.cat([k1, k2], dim=2)
|
| 127 |
-
else:
|
| 128 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 129 |
-
|
| 130 |
-
query_states = query_states * _get_llama_4_attn_scale(
|
| 131 |
-
cache_position,
|
| 132 |
-
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 133 |
-
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 134 |
-
).to(query_states.dtype)
|
| 135 |
-
|
| 136 |
-
if past_key_values is not None:
|
| 137 |
-
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 138 |
-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 139 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 140 |
-
|
| 141 |
-
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 142 |
-
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 143 |
-
|
| 144 |
-
if self.block_diff_mask is None or q_len != self.block_diff_mask.shape[-2]:
|
| 145 |
-
block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
|
| 146 |
-
else:
|
| 147 |
-
block_mask = self.block_diff_mask
|
| 148 |
-
|
| 149 |
-
attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 150 |
-
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 151 |
-
|
| 152 |
-
attn_output = self.o_proj(attn_output)
|
| 153 |
-
|
| 154 |
-
return attn_output, None
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
class NemotronLabsDiffusionModel(Ministral3PreTrainedModel, GenerationMixin):
|
| 158 |
-
"""
|
| 159 |
-
A single model with:
|
| 160 |
-
- a bidirectional encoder + diffusion‐LM head over A
|
| 161 |
-
- a causal decoder + LM head over B, conditioned on F_A
|
| 162 |
-
"""
|
| 163 |
-
|
| 164 |
-
def __init__(self, config: NemotronLabsDiffusionConfig):
|
| 165 |
-
super().__init__(config)
|
| 166 |
-
|
| 167 |
-
self.mask_token_id = config.mask_token_id
|
| 168 |
-
|
| 169 |
-
diffusion_config = copy.deepcopy(config)
|
| 170 |
-
diffusion_config.diffusion_lm = True
|
| 171 |
-
|
| 172 |
-
if config.dlm_paradigm == 'block_diff':
|
| 173 |
-
diffusion_config.attn_class = NemotronLabsDiffusionFlexAttention
|
| 174 |
-
elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
|
| 175 |
-
diffusion_config.attn_class = Ministral3Attention
|
| 176 |
-
if config.dlm_paradigm == 'autoregressive':
|
| 177 |
-
diffusion_config.diffusion_lm = False
|
| 178 |
-
else:
|
| 179 |
-
raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
|
| 180 |
-
|
| 181 |
-
self.encoder = Ministral3Model(diffusion_config)
|
| 182 |
-
self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 183 |
-
self.vocab_size = config.vocab_size
|
| 184 |
-
|
| 185 |
-
self.post_init()
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def get_input_embeddings(self):
|
| 189 |
-
return self.encoder.embed_tokens
|
| 190 |
-
|
| 191 |
-
def set_input_embeddings(self, value):
|
| 192 |
-
self.encoder.embed_tokens = value
|
| 193 |
-
|
| 194 |
-
def get_output_embeddings(self):
|
| 195 |
-
return self.diffusion_head
|
| 196 |
-
|
| 197 |
-
def set_output_embeddings(self, new_embeddings):
|
| 198 |
-
self.diffusion_head = new_embeddings
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
|
| 202 |
-
b, l = input_ids.shape
|
| 203 |
-
device = input_ids.device
|
| 204 |
-
|
| 205 |
-
if self.config.dp_varying_mask_ratio:
|
| 206 |
-
# Enable different random seeds for each DP rank during sampling
|
| 207 |
-
import torch.distributed as dist
|
| 208 |
-
dp_rank = 0
|
| 209 |
-
if dist.is_initialized():
|
| 210 |
-
try:
|
| 211 |
-
dp_rank = dist.get_rank()
|
| 212 |
-
except Exception:
|
| 213 |
-
dp_rank = 0
|
| 214 |
-
# Use a local generator to avoid affecting global RNG state
|
| 215 |
-
generator = torch.Generator(device=device)
|
| 216 |
-
generator.manual_seed(torch.seed() + dp_rank)
|
| 217 |
-
else:
|
| 218 |
-
generator = None
|
| 219 |
-
|
| 220 |
-
t = torch.rand(b, device=device, generator=generator)
|
| 221 |
-
|
| 222 |
-
p_mask = (1 - eps) * t + eps # shape: (b,)
|
| 223 |
-
p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
|
| 224 |
-
|
| 225 |
-
masked_indices = torch.rand((b, l), device=device) < p_mask
|
| 226 |
-
|
| 227 |
-
if loss_mask is not None:
|
| 228 |
-
masked_indices[loss_mask == 0] = 0
|
| 229 |
-
|
| 230 |
-
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 231 |
-
|
| 232 |
-
return noisy_batch, masked_indices, p_mask
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def forward(
|
| 236 |
-
self,
|
| 237 |
-
input_ids: torch.LongTensor,
|
| 238 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 239 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 240 |
-
labels: Optional[torch.LongTensor] = None,
|
| 241 |
-
split_len: Optional[int] = None,
|
| 242 |
-
past_key_values: Optional[Cache] = None,
|
| 243 |
-
block_size: Optional[int] = None,
|
| 244 |
-
eps: float = 1e-3,
|
| 245 |
-
is_teacher: bool = False,
|
| 246 |
-
masked_indices: Optional[torch.Tensor] = None,
|
| 247 |
-
p_mask: Optional[torch.Tensor] = None,
|
| 248 |
-
teacher_logits: Optional[torch.Tensor] = None,
|
| 249 |
-
masked_indices_teacher: Optional[torch.Tensor] = None,
|
| 250 |
-
loss_mask: Optional[torch.Tensor] = None,
|
| 251 |
-
ce_loss_weight: float = 1.0,
|
| 252 |
-
output_last_hidden_states_only: bool = False,
|
| 253 |
-
skip_loss: bool = False,
|
| 254 |
-
**kwargs,
|
| 255 |
-
) -> CausalLMOutputWithPast:
|
| 256 |
-
|
| 257 |
-
batch_size, seq_len = input_ids.shape
|
| 258 |
-
|
| 259 |
-
if self.config.dlm_paradigm == 'block_diff':
|
| 260 |
-
if labels is not None and block_size is None:
|
| 261 |
-
block_size = self.config.block_size
|
| 262 |
-
elif self.config.dlm_paradigm not in ('bidirectional', 'autoregressive'):
|
| 263 |
-
raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
|
| 264 |
-
|
| 265 |
-
if labels is not None and self.config.dlm_paradigm != 'autoregressive':
|
| 266 |
-
if masked_indices is not None:
|
| 267 |
-
# assert p_mask is not None
|
| 268 |
-
|
| 269 |
-
if loss_mask is not None:
|
| 270 |
-
masked_indices[loss_mask == 0] = 0
|
| 271 |
-
|
| 272 |
-
noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 273 |
-
|
| 274 |
-
else:
|
| 275 |
-
noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
|
| 276 |
-
|
| 277 |
-
else:
|
| 278 |
-
noisy_inputs = input_ids
|
| 279 |
-
masked_indices = None
|
| 280 |
-
p_mask = None
|
| 281 |
-
|
| 282 |
-
input_ids_len = noisy_inputs.shape[1]
|
| 283 |
-
if labels is not None and self.config.dlm_paradigm == 'block_diff':
|
| 284 |
-
if position_ids is None:
|
| 285 |
-
position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
|
| 286 |
-
noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
|
| 287 |
-
|
| 288 |
-
enc_out = self.encoder(
|
| 289 |
-
past_key_values=past_key_values,
|
| 290 |
-
input_ids=noisy_inputs,
|
| 291 |
-
attention_mask=attention_mask,
|
| 292 |
-
position_ids=position_ids,
|
| 293 |
-
is_training=(labels is not None),
|
| 294 |
-
**kwargs,
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
if output_last_hidden_states_only:
|
| 298 |
-
return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
|
| 299 |
-
|
| 300 |
-
logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
|
| 301 |
-
causal_logits = None
|
| 302 |
-
|
| 303 |
-
if labels is not None and self.config.dlm_paradigm == 'block_diff':
|
| 304 |
-
causal_logits = logits[:, input_ids_len:]
|
| 305 |
-
logits = logits[:, :input_ids_len]
|
| 306 |
-
|
| 307 |
-
loss = None
|
| 308 |
-
if labels is not None and not skip_loss:
|
| 309 |
-
if self.config.dlm_paradigm == 'autoregressive':
|
| 310 |
-
shift_logits = logits[..., :-1, :].contiguous()
|
| 311 |
-
shift_labels = labels[..., 1:].contiguous()
|
| 312 |
-
|
| 313 |
-
if loss_mask is None:
|
| 314 |
-
loss_fct = CrossEntropyLoss()
|
| 315 |
-
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 316 |
-
shift_labels = shift_labels.view(-1)
|
| 317 |
-
loss = loss_fct(shift_logits, shift_labels)
|
| 318 |
-
|
| 319 |
-
else:
|
| 320 |
-
loss_mask = loss_mask[..., 1:].contiguous()
|
| 321 |
-
|
| 322 |
-
loss_fct = CrossEntropyLoss(reduction='none')
|
| 323 |
-
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 324 |
-
shift_labels = shift_labels.view(-1)
|
| 325 |
-
shift_labels = shift_labels.to(shift_logits.device)
|
| 326 |
-
|
| 327 |
-
token_losses = loss_fct(shift_logits, shift_labels)
|
| 328 |
-
|
| 329 |
-
flat_loss_mask = loss_mask.reshape(-1)
|
| 330 |
-
loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
|
| 331 |
-
|
| 332 |
-
else:
|
| 333 |
-
# LLaDA-style diffusion loss on masked positions.
|
| 334 |
-
# Token-wise cross entropy loss on masked positions.
|
| 335 |
-
token_loss = torch.nn.functional.cross_entropy(
|
| 336 |
-
logits[masked_indices],
|
| 337 |
-
labels[masked_indices],
|
| 338 |
-
reduction='none'
|
| 339 |
-
) / p_mask[masked_indices]
|
| 340 |
-
|
| 341 |
-
num_mask_tokens = masked_indices.sum()
|
| 342 |
-
|
| 343 |
-
# global_loss_avg=True: loss is reduced externally by global token count.
|
| 344 |
-
loss = token_loss.sum()
|
| 345 |
-
|
| 346 |
-
if self.config.dlm_loss_weight is not None:
|
| 347 |
-
loss = self.config.dlm_loss_weight * loss
|
| 348 |
-
|
| 349 |
-
if self.config.dlm_paradigm == 'block_diff':
|
| 350 |
-
# AR-side loss for block-diffusion paradigm.
|
| 351 |
-
causal_logits = causal_logits[..., :-1, :].contiguous()
|
| 352 |
-
causal_logits = causal_logits.view(-1, causal_logits.size(-1))
|
| 353 |
-
causal_labels = labels[..., 1:].contiguous().view(-1)
|
| 354 |
-
|
| 355 |
-
loss_fct = CrossEntropyLoss(reduction='sum')
|
| 356 |
-
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 357 |
-
|
| 358 |
-
self.loss_diffusion = loss.detach().item() / num_mask_tokens
|
| 359 |
-
self.loss_ar = ar_loss.detach().item() / seq_len
|
| 360 |
-
|
| 361 |
-
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 362 |
-
|
| 363 |
-
# global_loss_avg=True: return (sum_loss, token_count) for external mean.
|
| 364 |
-
if self.config.dlm_paradigm == 'block_diff':
|
| 365 |
-
loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
|
| 366 |
-
else:
|
| 367 |
-
loss = (loss, num_mask_tokens)
|
| 368 |
-
|
| 369 |
-
return NemotronLabsDiffusionOutputWithPast(
|
| 370 |
-
loss=loss if not is_teacher else logits,
|
| 371 |
-
logits=logits,
|
| 372 |
-
causal_logits=causal_logits,
|
| 373 |
-
past_key_values=enc_out.past_key_values,
|
| 374 |
-
hidden_states=None,
|
| 375 |
-
attentions=None,
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
@torch.no_grad()
|
| 380 |
-
def generate(
|
| 381 |
-
self,
|
| 382 |
-
prompt_ids: torch.Tensor,
|
| 383 |
-
max_new_tokens: int,
|
| 384 |
-
block_length: int,
|
| 385 |
-
threshold: Optional[float] = None,
|
| 386 |
-
causal_context: bool = True,
|
| 387 |
-
temperature: float = 0.0,
|
| 388 |
-
eos_token_id: Optional[int] = None,
|
| 389 |
-
max_thinking_tokens: Optional[int] = None,
|
| 390 |
-
end_think_token_id: Optional[int] = None,
|
| 391 |
-
):
|
| 392 |
-
"""Block-wise diffusion decoding with prefix-cached KV (LLaDA-style).
|
| 393 |
-
|
| 394 |
-
Each block: append `block_length` mask tokens, then iteratively unmask
|
| 395 |
-
by confidence top-k (with optional threshold). When `causal_context`,
|
| 396 |
-
the KV cache and the next-block seed are produced via a causal forward
|
| 397 |
-
between blocks (flipping `self_attn.diffusion_lm`), matching the AR
|
| 398 |
-
objective at block boundaries.
|
| 399 |
-
|
| 400 |
-
Returns (output_ids, nfe) — output_ids includes the prompt.
|
| 401 |
-
"""
|
| 402 |
-
if eos_token_id is None:
|
| 403 |
-
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 404 |
-
mask_id = self.mask_token_id
|
| 405 |
-
|
| 406 |
-
x_accum = prompt_ids.clone()
|
| 407 |
-
B = prompt_ids.shape[0]
|
| 408 |
-
|
| 409 |
-
assert max_new_tokens % block_length == 0
|
| 410 |
-
num_blocks = max_new_tokens // block_length
|
| 411 |
-
# one denoising step per generated token (matches legacy chat_utils call)
|
| 412 |
-
steps_per_block = block_length
|
| 413 |
-
|
| 414 |
-
nfe = 0
|
| 415 |
-
|
| 416 |
-
def _set_diffusion_lm(val: bool):
|
| 417 |
-
for layer in self.encoder.layers:
|
| 418 |
-
if hasattr(layer.self_attn, "diffusion_lm"):
|
| 419 |
-
layer.self_attn.diffusion_lm = val
|
| 420 |
-
|
| 421 |
-
# Initial causal prefill produces the KV cache and the next-block seed.
|
| 422 |
-
if causal_context:
|
| 423 |
-
_set_diffusion_lm(False)
|
| 424 |
-
output = self(prompt_ids, use_cache=True, use_causal_mask=causal_context)
|
| 425 |
-
past_key_values = output.past_key_values
|
| 426 |
-
if causal_context:
|
| 427 |
-
_set_diffusion_lm(True)
|
| 428 |
-
|
| 429 |
-
next_token = None
|
| 430 |
-
if causal_context:
|
| 431 |
-
last_logit = output.logits[:, -1, :]
|
| 432 |
-
if temperature > 0:
|
| 433 |
-
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 434 |
-
else:
|
| 435 |
-
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 436 |
-
|
| 437 |
-
for num_block in range(num_blocks):
|
| 438 |
-
mask_block = torch.full(
|
| 439 |
-
(B, block_length), mask_id, dtype=prompt_ids.dtype, device=prompt_ids.device,
|
| 440 |
-
)
|
| 441 |
-
if causal_context:
|
| 442 |
-
mask_block[:, 0] = next_token[:, 0]
|
| 443 |
-
|
| 444 |
-
x_accum = torch.cat([x_accum, mask_block], dim=1)
|
| 445 |
-
block_start = prompt_ids.size(1) + num_block * block_length
|
| 446 |
-
block_slice = slice(block_start, block_start + block_length)
|
| 447 |
-
|
| 448 |
-
# Thinking-budget enforcement: if we've passed max_thinking_tokens
|
| 449 |
-
# without an end-think marker, inject one into this block.
|
| 450 |
-
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 451 |
-
tokens_before = num_block * block_length
|
| 452 |
-
tokens_after = tokens_before + block_length
|
| 453 |
-
if tokens_after > max_thinking_tokens:
|
| 454 |
-
gen_so_far = x_accum[:, prompt_ids.size(1):block_start]
|
| 455 |
-
has_end_think = (
|
| 456 |
-
(gen_so_far == end_think_token_id).any(dim=1)
|
| 457 |
-
if gen_so_far.size(1) > 0
|
| 458 |
-
else torch.zeros(B, dtype=torch.bool, device=prompt_ids.device)
|
| 459 |
-
)
|
| 460 |
-
if not has_end_think.all():
|
| 461 |
-
offset = max(0, max_thinking_tokens - tokens_before)
|
| 462 |
-
inject_pos = block_start + offset
|
| 463 |
-
for b in range(B):
|
| 464 |
-
if not has_end_think[b]:
|
| 465 |
-
x_accum[b, inject_pos] = end_think_token_id
|
| 466 |
-
|
| 467 |
-
mask_block_idx0 = x_accum[:, block_slice] == mask_id
|
| 468 |
-
num_transfer_tokens = _get_num_transfer_tokens(mask_block_idx0, steps_per_block)
|
| 469 |
-
|
| 470 |
-
# Denoise the current block by repeated confidence-based unmasking.
|
| 471 |
-
for i in range(steps_per_block):
|
| 472 |
-
mask_block_idx = x_accum[:, block_slice] == mask_id
|
| 473 |
-
if mask_block_idx.sum() == 0:
|
| 474 |
-
break
|
| 475 |
-
|
| 476 |
-
nfe += 1
|
| 477 |
-
logits_block = self(
|
| 478 |
-
x_accum[:, block_slice],
|
| 479 |
-
past_key_values=past_key_values,
|
| 480 |
-
use_cache=False,
|
| 481 |
-
).logits
|
| 482 |
-
|
| 483 |
-
x0, transfer_idx = _get_transfer_index(
|
| 484 |
-
logits_block, temperature, mask_block_idx, x_accum[:, block_slice],
|
| 485 |
-
num_transfer_tokens=num_transfer_tokens[:, i], threshold=threshold,
|
| 486 |
-
)
|
| 487 |
-
cur = x_accum[:, block_slice].clone()
|
| 488 |
-
cur[transfer_idx] = x0[transfer_idx]
|
| 489 |
-
x_accum[:, block_slice] = cur
|
| 490 |
-
|
| 491 |
-
if eos_token_id is not None:
|
| 492 |
-
block_tokens = x_accum[:, block_slice]
|
| 493 |
-
eos_mask = block_tokens == eos_token_id
|
| 494 |
-
if eos_mask.any(dim=1).any():
|
| 495 |
-
after_eos = eos_mask.cumsum(dim=1).bool()
|
| 496 |
-
mask_before = (block_tokens == mask_id) & ~after_eos
|
| 497 |
-
if (eos_mask.any(dim=1) & ~mask_before.any(dim=1)).any():
|
| 498 |
-
break
|
| 499 |
-
|
| 500 |
-
# Post-block: causal forward over the block to update the KV cache
|
| 501 |
-
# and (when causal_context) sample the seed for the next block.
|
| 502 |
-
if causal_context:
|
| 503 |
-
_set_diffusion_lm(False)
|
| 504 |
-
output = self(
|
| 505 |
-
x_accum[:, block_slice],
|
| 506 |
-
past_key_values=past_key_values,
|
| 507 |
-
use_cache=True,
|
| 508 |
-
use_causal_mask=causal_context,
|
| 509 |
-
)
|
| 510 |
-
past_key_values = output.past_key_values
|
| 511 |
-
nfe += 1
|
| 512 |
-
|
| 513 |
-
if causal_context:
|
| 514 |
-
_set_diffusion_lm(True)
|
| 515 |
-
last_logit = output.logits[:, -1, :]
|
| 516 |
-
if temperature > 0:
|
| 517 |
-
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 518 |
-
else:
|
| 519 |
-
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 520 |
-
|
| 521 |
-
if eos_token_id is not None:
|
| 522 |
-
gen_so_far = x_accum[:, prompt_ids.size(1):]
|
| 523 |
-
is_eos = gen_so_far == eos_token_id
|
| 524 |
-
if is_eos.any(dim=1).all():
|
| 525 |
-
first_eos = is_eos.to(torch.int64).argmax(dim=1)
|
| 526 |
-
max_eos = first_eos.max().item()
|
| 527 |
-
return x_accum[:, : prompt_ids.size(1) + max_eos + 1], nfe
|
| 528 |
-
|
| 529 |
-
return x_accum, nfe
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
@torch.no_grad()
|
| 534 |
-
def ar_generate(
|
| 535 |
-
self,
|
| 536 |
-
prompt_ids: torch.Tensor,
|
| 537 |
-
max_new_tokens: int = 128,
|
| 538 |
-
temperature: float = 0.0,
|
| 539 |
-
eos_token_id: Optional[int] = None,
|
| 540 |
-
max_thinking_tokens: Optional[int] = None,
|
| 541 |
-
end_think_token_id: Optional[int] = None,
|
| 542 |
-
) -> tuple:
|
| 543 |
-
"""Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
|
| 544 |
-
|
| 545 |
-
Bypasses NemotronLabsDiffusionModel.forward() to avoid diffusion-specific
|
| 546 |
-
code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
|
| 547 |
-
position_ids, and use_cache so the KV cache and causal masking behave
|
| 548 |
-
identically to MistralForCausalLM / vLLM.
|
| 549 |
-
|
| 550 |
-
Returns:
|
| 551 |
-
(output_ids, nfe) where output_ids includes the prompt.
|
| 552 |
-
"""
|
| 553 |
-
for layer in self.encoder.layers:
|
| 554 |
-
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 555 |
-
layer.self_attn.diffusion_lm = False
|
| 556 |
-
|
| 557 |
-
if eos_token_id is None:
|
| 558 |
-
eos_token_id = getattr(self.config, 'eos_token_id', None)
|
| 559 |
-
|
| 560 |
-
device = prompt_ids.device
|
| 561 |
-
batch_size, prompt_len = prompt_ids.shape
|
| 562 |
-
|
| 563 |
-
past_key_values = DynamicCache()
|
| 564 |
-
cache_position = torch.arange(prompt_len, device=device)
|
| 565 |
-
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
| 566 |
-
|
| 567 |
-
enc_out = self.encoder(
|
| 568 |
-
input_ids=prompt_ids,
|
| 569 |
-
position_ids=position_ids,
|
| 570 |
-
past_key_values=past_key_values,
|
| 571 |
-
use_cache=True,
|
| 572 |
-
cache_position=cache_position,
|
| 573 |
-
)
|
| 574 |
-
past_key_values = enc_out.past_key_values
|
| 575 |
-
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 576 |
-
|
| 577 |
-
generated_tokens = []
|
| 578 |
-
nfe = 0
|
| 579 |
-
|
| 580 |
-
for step in range(max_new_tokens):
|
| 581 |
-
nfe += 1
|
| 582 |
-
|
| 583 |
-
if temperature > 0:
|
| 584 |
-
probs = torch.softmax(next_logit / temperature, dim=-1)
|
| 585 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 586 |
-
else:
|
| 587 |
-
next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
|
| 588 |
-
|
| 589 |
-
# ---- thinking budget enforcement ----
|
| 590 |
-
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 591 |
-
if step >= max_thinking_tokens:
|
| 592 |
-
if generated_tokens:
|
| 593 |
-
gen_tensor = torch.cat(generated_tokens, dim=1)
|
| 594 |
-
has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
|
| 595 |
-
else:
|
| 596 |
-
has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 597 |
-
for b in range(batch_size):
|
| 598 |
-
if not has_end_think[b]:
|
| 599 |
-
next_token[b] = end_think_token_id
|
| 600 |
-
|
| 601 |
-
generated_tokens.append(next_token)
|
| 602 |
-
|
| 603 |
-
if eos_token_id is not None and (next_token == eos_token_id).all():
|
| 604 |
-
break
|
| 605 |
-
|
| 606 |
-
if step < max_new_tokens - 1:
|
| 607 |
-
cur_pos = prompt_len + step
|
| 608 |
-
step_cache_pos = torch.tensor([cur_pos], device=device)
|
| 609 |
-
step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
|
| 610 |
-
|
| 611 |
-
enc_out = self.encoder(
|
| 612 |
-
input_ids=next_token,
|
| 613 |
-
position_ids=step_pos_ids,
|
| 614 |
-
past_key_values=past_key_values,
|
| 615 |
-
use_cache=True,
|
| 616 |
-
cache_position=step_cache_pos,
|
| 617 |
-
)
|
| 618 |
-
past_key_values = enc_out.past_key_values
|
| 619 |
-
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 620 |
-
|
| 621 |
-
all_generated = torch.cat(generated_tokens, dim=1)
|
| 622 |
-
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 623 |
-
return output_ids, nfe
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
@torch.no_grad()
|
| 627 |
-
def linear_spec_generate(
|
| 628 |
-
self,
|
| 629 |
-
prompt_ids: torch.Tensor,
|
| 630 |
-
max_new_tokens: int = 128,
|
| 631 |
-
block_length: int = 32,
|
| 632 |
-
temperature: float = 0.0,
|
| 633 |
-
mask_token_id: Optional[int] = None,
|
| 634 |
-
eos_token_id: Optional[int] = None,
|
| 635 |
-
max_thinking_tokens: Optional[int] = None,
|
| 636 |
-
end_think_token_id: Optional[int] = None,
|
| 637 |
-
threshold: float = 0.0,
|
| 638 |
-
):
|
| 639 |
-
"""Linear speculative decoding: diffusion draft + AR verify.
|
| 640 |
-
|
| 641 |
-
Each iteration: (1) draft the next block under bidirectional attention,
|
| 642 |
-
(2) verify the drafted block under causal attention, accept the longest
|
| 643 |
-
prefix where draft matches AR + one bonus token, advance the KV cache.
|
| 644 |
-
|
| 645 |
-
LoRA-aware: when a PEFT adapter is attached to the model (e.g.
|
| 646 |
-
``linear_spec_lora``), it is toggled ON for the bidirectional draft
|
| 647 |
-
phase and OFF for the causal prefill / verify phases — so the adapter
|
| 648 |
-
only specializes the diffusion-mode forward and AR semantics are
|
| 649 |
-
preserved. With no adapter loaded the calls are no-ops.
|
| 650 |
-
|
| 651 |
-
Returns ``(output_ids, nfe)`` — ``output_ids`` includes the prompt.
|
| 652 |
-
"""
|
| 653 |
-
if prompt_ids.shape[0] != 1:
|
| 654 |
-
raise ValueError("Linear speculative decoding requires batch_size == 1")
|
| 655 |
-
|
| 656 |
-
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 657 |
-
if eos_token_id is None:
|
| 658 |
-
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 659 |
-
|
| 660 |
-
device = prompt_ids.device
|
| 661 |
-
|
| 662 |
-
def _set_diffusion_lm(val: bool):
|
| 663 |
-
for layer in self.encoder.layers:
|
| 664 |
-
if hasattr(layer.self_attn, "diffusion_lm"):
|
| 665 |
-
layer.self_attn.diffusion_lm = val
|
| 666 |
-
|
| 667 |
-
def _toggle_adapters(enable: bool):
|
| 668 |
-
# No-op when no PEFT/LoRA modules are attached.
|
| 669 |
-
for module in self.modules():
|
| 670 |
-
if hasattr(module, "_disable_adapters"):
|
| 671 |
-
module._disable_adapters = not enable
|
| 672 |
-
|
| 673 |
-
# Prefill (causal, LoRA OFF).
|
| 674 |
-
_set_diffusion_lm(False)
|
| 675 |
-
_toggle_adapters(False)
|
| 676 |
-
enc_out = self.encoder(
|
| 677 |
-
input_ids=prompt_ids,
|
| 678 |
-
past_key_values=DynamicCache(),
|
| 679 |
-
use_cache=True,
|
| 680 |
-
use_causal_mask=True,
|
| 681 |
-
)
|
| 682 |
-
past_key_values = enc_out.past_key_values
|
| 683 |
-
last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 684 |
-
nfe = 1
|
| 685 |
-
|
| 686 |
-
if temperature > 0:
|
| 687 |
-
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 688 |
-
else:
|
| 689 |
-
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 690 |
-
|
| 691 |
-
if eos_token_id is not None and next_token.item() == eos_token_id:
|
| 692 |
-
return torch.cat([prompt_ids, next_token], dim=1), nfe
|
| 693 |
-
|
| 694 |
-
generated = [next_token]
|
| 695 |
-
total_gen = 1
|
| 696 |
-
|
| 697 |
-
while total_gen < max_new_tokens:
|
| 698 |
-
cache_len = past_key_values.get_seq_length()
|
| 699 |
-
|
| 700 |
-
block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
|
| 701 |
-
block[0, 0] = next_token.item()
|
| 702 |
-
|
| 703 |
-
# Draft phase (bidirectional, LoRA ON) — iterate at threshold>0 so
|
| 704 |
-
# that even low-confidence blocks make progress.
|
| 705 |
-
_set_diffusion_lm(True)
|
| 706 |
-
_toggle_adapters(True)
|
| 707 |
-
while True:
|
| 708 |
-
is_mask = block == token_mask_id
|
| 709 |
-
if not is_mask.any():
|
| 710 |
-
break
|
| 711 |
-
|
| 712 |
-
enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
|
| 713 |
-
nfe += 1
|
| 714 |
-
|
| 715 |
-
draft_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 716 |
-
# LLaDA: logit[i] directly predicts position i — no shift needed.
|
| 717 |
-
|
| 718 |
-
if temperature > 0:
|
| 719 |
-
draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
|
| 720 |
-
draft_tokens = torch.multinomial(
|
| 721 |
-
draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
|
| 722 |
-
).view(1, block_length)
|
| 723 |
-
else:
|
| 724 |
-
draft_tokens = draft_logits.argmax(dim=-1)
|
| 725 |
-
draft_probs = torch.softmax(draft_logits, dim=-1)
|
| 726 |
-
|
| 727 |
-
if threshold > 0:
|
| 728 |
-
draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
|
| 729 |
-
draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
|
| 730 |
-
unmask = draft_conf >= threshold
|
| 731 |
-
# Force progress even when every masked position is below threshold.
|
| 732 |
-
if not unmask.any():
|
| 733 |
-
best_idx = draft_conf.view(-1).argmax()
|
| 734 |
-
unmask = torch.zeros_like(is_mask, dtype=torch.bool)
|
| 735 |
-
unmask.view(-1)[best_idx] = True
|
| 736 |
-
block[unmask] = draft_tokens[unmask]
|
| 737 |
-
else:
|
| 738 |
-
block[is_mask] = draft_tokens[is_mask]
|
| 739 |
-
break
|
| 740 |
-
|
| 741 |
-
# Verify phase (causal, LoRA OFF).
|
| 742 |
-
_set_diffusion_lm(False)
|
| 743 |
-
_toggle_adapters(False)
|
| 744 |
-
enc_out = self.encoder(
|
| 745 |
-
input_ids=block,
|
| 746 |
-
past_key_values=past_key_values,
|
| 747 |
-
use_cache=True,
|
| 748 |
-
use_causal_mask=True,
|
| 749 |
-
)
|
| 750 |
-
past_key_values = enc_out.past_key_values
|
| 751 |
-
nfe += 1
|
| 752 |
-
|
| 753 |
-
verify_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 754 |
-
if temperature > 0:
|
| 755 |
-
ar_tokens = torch.multinomial(
|
| 756 |
-
torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]),
|
| 757 |
-
num_samples=1,
|
| 758 |
-
).view(1, block_length)
|
| 759 |
-
else:
|
| 760 |
-
ar_tokens = verify_logits.argmax(dim=-1)
|
| 761 |
-
|
| 762 |
-
# Accept consecutive matches; AR also gives one bonus token at the tail.
|
| 763 |
-
accepted = 0
|
| 764 |
-
for i in range(block_length - 1):
|
| 765 |
-
if ar_tokens[0, i].item() == block[0, i + 1].item():
|
| 766 |
-
accepted += 1
|
| 767 |
-
else:
|
| 768 |
-
break
|
| 769 |
-
accepted += 1
|
| 770 |
-
|
| 771 |
-
accepted_toks = ar_tokens[:, :accepted]
|
| 772 |
-
generated.append(accepted_toks)
|
| 773 |
-
total_gen += accepted
|
| 774 |
-
|
| 775 |
-
_crop_dynamic_cache(past_key_values, cache_len + accepted)
|
| 776 |
-
next_token = ar_tokens[:, accepted - 1 : accepted]
|
| 777 |
-
|
| 778 |
-
if eos_token_id is not None:
|
| 779 |
-
eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
| 780 |
-
if len(eos_pos) > 0:
|
| 781 |
-
first_eos = eos_pos[0].item()
|
| 782 |
-
generated[-1] = accepted_toks[:, : first_eos + 1]
|
| 783 |
-
total_gen = total_gen - accepted + first_eos + 1
|
| 784 |
-
break
|
| 785 |
-
|
| 786 |
-
# Thinking-budget enforcement: force end-think as next seed if budget exhausted.
|
| 787 |
-
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 788 |
-
if total_gen > max_thinking_tokens:
|
| 789 |
-
all_gen = torch.cat(generated, dim=1)
|
| 790 |
-
if not (all_gen == end_think_token_id).any():
|
| 791 |
-
next_token = torch.tensor([[end_think_token_id]], device=device)
|
| 792 |
-
|
| 793 |
-
if total_gen >= max_new_tokens:
|
| 794 |
-
break
|
| 795 |
-
|
| 796 |
-
all_generated = torch.cat(generated, dim=1)
|
| 797 |
-
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 798 |
-
return output_ids, nfe
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
# ─── Module-level helpers used by `generate` and `linear_spec_generate` ──
|
| 802 |
-
|
| 803 |
-
def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
|
| 804 |
-
"""Crop a DynamicCache to max_length, compatible with both old and new transformers."""
|
| 805 |
-
if hasattr(past_key_values, 'crop'):
|
| 806 |
-
past_key_values.crop(max_length)
|
| 807 |
-
else:
|
| 808 |
-
for layer_idx in range(len(past_key_values)):
|
| 809 |
-
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
|
| 810 |
-
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
|
| 811 |
-
past_key_values._seen_tokens = max_length
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
def _add_gumbel_noise(logits, temperature):
|
| 815 |
-
"""Gumbel-max sampling in float64 (low-precision Gumbel hurts MDM quality)."""
|
| 816 |
-
if temperature == 0:
|
| 817 |
-
return logits
|
| 818 |
-
logits = logits.to(torch.float64)
|
| 819 |
-
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 820 |
-
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 821 |
-
return logits.exp() / gumbel_noise
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
def _get_num_transfer_tokens(mask_index, steps: int):
|
| 825 |
-
"""Even split of masked positions across `steps`, with remainder front-loaded."""
|
| 826 |
-
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 827 |
-
base = mask_num // steps
|
| 828 |
-
remainder = mask_num % steps
|
| 829 |
-
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 830 |
-
for i in range(mask_num.size(0)):
|
| 831 |
-
num_transfer_tokens[i, : int(remainder[i])] += 1
|
| 832 |
-
return num_transfer_tokens
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
def _get_transfer_index(logits, temperature, mask_index, x, num_transfer_tokens, threshold=None):
|
| 836 |
-
"""Pick which masked positions to commit this denoising step.
|
| 837 |
-
|
| 838 |
-
Returns (x0, transfer_index): x0 is argmax tokens (clamped to original x at
|
| 839 |
-
non-masked positions); transfer_index is a bool mask over positions to
|
| 840 |
-
finalize, chosen by top-k confidence (and filtered by `threshold` if given).
|
| 841 |
-
"""
|
| 842 |
-
logits_with_noise = _add_gumbel_noise(logits, temperature=temperature)
|
| 843 |
-
x0 = torch.argmax(logits_with_noise, dim=-1)
|
| 844 |
-
|
| 845 |
-
p = F.softmax(logits, dim=-1)
|
| 846 |
-
x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
|
| 847 |
-
|
| 848 |
-
x0 = torch.where(mask_index, x0, x)
|
| 849 |
-
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 850 |
-
|
| 851 |
-
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 852 |
-
if threshold is not None:
|
| 853 |
-
num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
|
| 854 |
-
for j in range(confidence.shape[0]):
|
| 855 |
-
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
|
| 856 |
-
transfer_index[j, select_index] = True
|
| 857 |
-
if threshold is not None:
|
| 858 |
-
for k in range(1, num_transfer_tokens[j]):
|
| 859 |
-
if confidence[j, select_index[k]] < threshold:
|
| 860 |
-
transfer_index[j, select_index[k]] = False
|
| 861 |
-
return x0, transfer_index
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
|
| 865 |
-
"""Return a Bool mask of length len(log_w) with exactly k True."""
|
| 866 |
-
g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
|
| 867 |
-
topk = torch.topk(log_w + g, k).indices
|
| 868 |
-
mask = torch.zeros_like(log_w, dtype=torch.bool)
|
| 869 |
-
mask[topk] = True
|
| 870 |
-
return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|