Text Generation
Transformers
Safetensors
PyTorch
nemotron_labs_diffusion
feature-extraction
nvidia
conversational
custom_code
Instructions to use nvidia/Nemotron-Labs-Diffusion-8B-Base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/Nemotron-Labs-Diffusion-8B-Base with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="nvidia/Nemotron-Labs-Diffusion-8B-Base", trust_remote_code=True) messages = [ {"role": "user", "content": "Who are you?"}, ] pipe(messages)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nvidia/Nemotron-Labs-Diffusion-8B-Base", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use nvidia/Nemotron-Labs-Diffusion-8B-Base with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "nvidia/Nemotron-Labs-Diffusion-8B-Base" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-8B-Base", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-8B-Base
- SGLang
How to use nvidia/Nemotron-Labs-Diffusion-8B-Base with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-8B-Base" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-8B-Base", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-8B-Base" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-8B-Base", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }' - Docker Model Runner
How to use nvidia/Nemotron-Labs-Diffusion-8B-Base with Docker Model Runner:
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-8B-Base
Commit ·
cf02602
0
Parent(s):
Initial release of Nemotron-Labs-Diffusion-8B-Base
Browse filesCo-authored-by: abhgarg <abhgarg@users.noreply.huggingface.co>
Co-authored-by: trias702 <trias702@users.noreply.huggingface.co>
Co-authored-by: trias702 <trias702@users.noreply.huggingface.co>
Co-authored-by: pmolchanov <pmolchanov@users.noreply.huggingface.co>
- .gitattributes +41 -0
- README.md +160 -0
- assets/demo.gif +3 -0
- assets/demo.mp4 +3 -0
- assets/result_acc.png +3 -0
- assets/result_efficiency.png +3 -0
- assets/teaser.png +3 -0
- chat_template.jinja +7 -0
- config.json +49 -0
- configuration_nemotron_labs_diffusion.py +186 -0
- generation_config.json +7 -0
- model.safetensors +3 -0
- model_cards/bias.md +4 -0
- model_cards/explainability.md +13 -0
- model_cards/privacy.md +11 -0
- model_cards/safety.md +6 -0
- modeling_ministral.py +459 -0
- modeling_nemotron_labs_diffusion.py +870 -0
- special_tokens_map.json +23 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/result_acc.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/result_efficiency.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
license: other
|
| 4 |
+
license_name: nvidia-nemotron-open-model-license
|
| 5 |
+
license_link: >-
|
| 6 |
+
https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/
|
| 7 |
+
pipeline_tag: text-generation
|
| 8 |
+
tags:
|
| 9 |
+
- nvidia
|
| 10 |
+
- pytorch
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Nemotron-Labs-Diffusion-8B-Base
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
<div align="center" style="line-height: 1;">
|
| 17 |
+
<a href="https://d1qx31qr3h6wln.cloudfront.net/publications/Nemotron_Diffusion_Tech_Report_v1.pdf?VersionId=db8_EMO8B.vmU26.jr7Le9pN3MqcUDNL" target="_blank" style="margin: 2px;">
|
| 18 |
+
<img alt="Chat" src="https://img.shields.io/badge/📝Paper-Read Now!-536af5?color=76B900&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 19 |
+
</a>
|
| 20 |
+
<a href="https://huggingface.co/collections/nvidia/nemotron-labs-diffusion" target="_blank" style="margin: 2px;">
|
| 21 |
+
<img alt="Nemotron-Labs-Diffusion Model Family" src="https://img.shields.io/badge/%F0%9F%A4%97-Nemotron--Labs--Diffusion_Model_Family-76B900" style="display: inline-block; vertical-align: middle;"/>
|
| 22 |
+
</a>
|
| 23 |
+
<a href="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/" style="margin: 2px;">
|
| 24 |
+
<img alt="License" src="https://img.shields.io/badge/License-NVIDIA Open Model License-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
|
| 25 |
+
</a>
|
| 26 |
+
</div>
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
[](./assets/demo.mp4)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## Model Overview
|
| 33 |
+
|
| 34 |
+
Nemotron-Labs-Diffusion is a tri-mode language model that supports both AR decoding and diffusion-based parallel decoding by simply switching the attention pattern of the same model during inference. The synergy between these two modes enables a third mode, called self-speculation: the same model performs diffusion-based parallel drafting and AR verification with shared KV cache, achieving high acceptance lengths and decoding efficiency. The seamless mode switching by simply changing attention patterns enables high efficiency at different concurrency levels in varying deployment scenarios with one single model.
|
| 35 |
+
|
| 36 |
+
<div align="center">
|
| 37 |
+
<img src="./assets/teaser.png" alt="An illustration of Tri-Mode LMs" width="500">
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## Highlights
|
| 42 |
+
|
| 43 |
+
- SOTA 3B, 8B, 14B dense LM family (base, instruct, and vision-language variants) supporting AR, diffusion, and self-speculation with the focus on decode efficiency.
|
| 44 |
+
- Generation moved from a memory-bound regime toward a compute-bound regime. Model weights are loaded once and reused to compute multiple tokens during generation.
|
| 45 |
+
- Self-speculation uses diffusion for drafting and AR for verification, providing a stronger alternative to MTP approaches:
|
| 46 |
+
* 3x higher acceptance length and 2.2x speed-up vs. Qwen3-8B-Eagle3 in SGLang.
|
| 47 |
+
* 5.9× tokens per forward over Qwen3-8B (no MTP) with the same accuracy.
|
| 48 |
+
- Real-device speed-up across platforms:
|
| 49 |
+
* DGX Spark (8B, concurrency 1): 2.7x faster with 112 tok/sec vs. 41.8 tok/sec AR using w4a16.
|
| 50 |
+
* GB200 (8B, concurrency 1): 3.3x faster with 850 tok/sec vs. 253 tok/sec AR and 360 tok/sec Eagle3. Custom CUDA kernels boost to 1015 tok/sec (4x).
|
| 51 |
+
- Diffusion speedup-of-light analysis shows that throughput can be further doubled (vs. current best) for a single user with better sampling - future research.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
<div align="center">
|
| 55 |
+
<img src="./assets/result_acc.png" alt="Efficiency Results" width="800">
|
| 56 |
+
</div>
|
| 57 |
+
|
| 58 |
+
<div align="center">
|
| 59 |
+
<img src="./assets/result_efficiency.png" alt="Acc Results" width="800">
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
## License/Terms of Use
|
| 64 |
+
|
| 65 |
+
Use of this model is governed by the [NVIDIA Nemotron Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/).
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
## Environment
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
transformers>=5.0.0
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Chat with Our Model
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
```
|
| 78 |
+
from transformers import AutoModel, AutoTokenizer
|
| 79 |
+
import torch
|
| 80 |
+
|
| 81 |
+
repo_name = "nvidia/Nemotron-Labs-Diffusion-8B-Base"
|
| 82 |
+
|
| 83 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 84 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
|
| 85 |
+
model = model.cuda().to(torch.bfloat16)
|
| 86 |
+
|
| 87 |
+
history = []
|
| 88 |
+
|
| 89 |
+
user_input = input("User: ").strip()
|
| 90 |
+
history.append({"role": "user", "content": user_input})
|
| 91 |
+
|
| 92 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
|
| 93 |
+
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
|
| 94 |
+
|
| 95 |
+
## Chat in AR Mode
|
| 96 |
+
out_ids, nfe = model.ar_generate(inputs.input_ids, max_new_tokens=512)
|
| 97 |
+
|
| 98 |
+
## Chat in dLM Mode
|
| 99 |
+
out_ids, nfe = model.generate(prompt_ids, max_new_tokens=512, block_length=32, threshold=0.9, eos_token_id=tokenizer.eos_token_id)
|
| 100 |
+
|
| 101 |
+
## Chat in Linear Self-Speculation Mode
|
| 102 |
+
out_ids, nfe = model.linear_spec_generate(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
|
| 103 |
+
|
| 104 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
|
| 105 |
+
print(f"Model: {tokenized_out}")
|
| 106 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
## Inference with Linear Self-Speculation + LoRA-enhanced Drafter
|
| 112 |
+
|
| 113 |
+
An optional LoRA adatper can be applied to the diffusion drafter in the linear self-speculation mode to further increase the acceptance length:
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
import torch
|
| 118 |
+
from transformers import AutoModel, AutoTokenizer
|
| 119 |
+
from peft import PeftModel
|
| 120 |
+
|
| 121 |
+
repo = "nvidia/Nemotron-Labs-Diffusion-8B-Base"
|
| 122 |
+
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
|
| 123 |
+
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
|
| 124 |
+
model = model.cuda().to(torch.bfloat16)
|
| 125 |
+
|
| 126 |
+
# Attach the linear_spec LoRA adapter.
|
| 127 |
+
model = PeftModel.from_pretrained(model, repo, subfolder="linear_spec_lora").eval()
|
| 128 |
+
# Unwrap so we can call linear_spec_generate directly (it toggles LoRA internally).
|
| 129 |
+
base = model.model
|
| 130 |
+
|
| 131 |
+
history = [{"role": "user", "content": "Solve: What is 15% of 240?"}]
|
| 132 |
+
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
|
| 133 |
+
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
|
| 134 |
+
|
| 135 |
+
out_ids, nfe = base.linear_spec_generate(
|
| 136 |
+
prompt_ids, max_new_tokens=512, block_length=32,
|
| 137 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 138 |
+
)
|
| 139 |
+
print(tokenizer.decode(out_ids[0, prompt_ids.shape[1]:], skip_special_tokens=True))
|
| 140 |
+
print(f"[NFE={nfe}]")
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
## Ethical Considerations
|
| 145 |
+
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the [bias](./model_cards/bias.md), [explainability](./model_cards/explainability.md), [safety & security](./model_cards/safety.md), and [privacy](./model_cards/privacy.md) subcards.
|
| 146 |
+
|
| 147 |
+
Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
## Citations
|
| 151 |
+
|
| 152 |
+
```bibtex
|
| 153 |
+
@techreport{fu2026nemotronlabsdiffusion,
|
| 154 |
+
title = {Nemotron-Labs-Diffusion: A Tri-Mode Language Model Unifying Autoregressive, Diffusion, and Self-Speculation Decoding},
|
| 155 |
+
author = {Yonggan Fu and Lexington Whalen and Abhinav Garg and Chengyue Wu and Maksim Khadkevich and Nicolai Oswald and Enze Xie and Daniel Egert and Sharath Turuvekere Sreenivas and Shizhe Diao and Chenhan Yu and Ye Yu and Weijia Chen and Sajad Norouzi and Shiyi Lan and Ligeng Zhu and Jin Wang and Jindong Jiang and Morteza Mardani and Mehran Maghoumi and Song Han and Ante Jukic and Nima Tajbakhsh and Jan Kautz and Pavlo Molchanov},
|
| 156 |
+
institution = {NVIDIA},
|
| 157 |
+
year = {2026},
|
| 158 |
+
note = {Technical report}
|
| 159 |
+
}
|
| 160 |
+
```
|
assets/demo.gif
ADDED
|
Git LFS Details
|
assets/demo.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:666d8785ac4af75931d9c677757c4ef9945bf114d07f1c4e2ebb7b893ac39006
|
| 3 |
+
size 9454873
|
assets/result_acc.png
ADDED
|
Git LFS Details
|
assets/result_efficiency.png
ADDED
|
Git LFS Details
|
assets/teaser.png
ADDED
|
Git LFS Details
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{{'<SPECIAL_10>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'
|
| 2 |
+
' + message['content'].strip()}}{% endif %}{% endfor %}{{'
|
| 3 |
+
'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '
|
| 4 |
+
<SPECIAL_11>User
|
| 5 |
+
' + message['content'].strip() + '
|
| 6 |
+
<SPECIAL_11>Assistant
|
| 7 |
+
' }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}
|
config.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"ar_loss_weight": 1.0,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NemotronLabsDiffusionModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_bias": false,
|
| 7 |
+
"attention_dropout": 0.0,
|
| 8 |
+
"attn_implementation": "sdpa",
|
| 9 |
+
"auto_map": {
|
| 10 |
+
"AutoConfig": "configuration_nemotron_labs_diffusion.NemotronLabsDiffusionConfig",
|
| 11 |
+
"AutoModel": "modeling_nemotron_labs_diffusion.NemotronLabsDiffusionModel"
|
| 12 |
+
},
|
| 13 |
+
"block_size": 32,
|
| 14 |
+
"bos_token_id": 1,
|
| 15 |
+
"dlm_loss_weight": null,
|
| 16 |
+
"dlm_paradigm": "bidirectional",
|
| 17 |
+
"dp_varying_mask_ratio": false,
|
| 18 |
+
"eos_token_id": 2,
|
| 19 |
+
"head_dim": 128,
|
| 20 |
+
"hidden_act": "silu",
|
| 21 |
+
"hidden_size": 4096,
|
| 22 |
+
"initializer_range": 0.02,
|
| 23 |
+
"intermediate_size": 14336,
|
| 24 |
+
"mask_token_id": 100,
|
| 25 |
+
"max_position_embeddings": 4096,
|
| 26 |
+
"mlp_bias": false,
|
| 27 |
+
"model_type": "nemotron_labs_diffusion",
|
| 28 |
+
"num_attention_heads": 32,
|
| 29 |
+
"num_hidden_layers": 34,
|
| 30 |
+
"num_key_value_heads": 8,
|
| 31 |
+
"rms_norm_eps": 1e-05,
|
| 32 |
+
"rope_parameters": {
|
| 33 |
+
"beta_fast": 32.0,
|
| 34 |
+
"beta_slow": 1.0,
|
| 35 |
+
"factor": 0.25,
|
| 36 |
+
"llama_4_scaling_beta": 0.1,
|
| 37 |
+
"mscale": 1.0,
|
| 38 |
+
"mscale_all_dim": 1.0,
|
| 39 |
+
"original_max_position_embeddings": 16384,
|
| 40 |
+
"rope_theta": 1000000.0,
|
| 41 |
+
"rope_type": "yarn"
|
| 42 |
+
},
|
| 43 |
+
"sliding_window": null,
|
| 44 |
+
"tie_word_embeddings": false,
|
| 45 |
+
"torch_dtype": "bfloat16",
|
| 46 |
+
"transformers_version": "5.0.0",
|
| 47 |
+
"use_cache": false,
|
| 48 |
+
"vocab_size": 131072
|
| 49 |
+
}
|
configuration_nemotron_labs_diffusion.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Nemotron-Labs Diffusion model configuration"""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 19 |
+
from transformers.utils import logging
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NemotronLabsDiffusionConfig(PretrainedConfig):
|
| 26 |
+
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`NemotronLabsDiffusionModel`] for diffusion language models.
|
| 28 |
+
It is used to instantiate a NemotronLabsDiffusionModel according to the specified arguments, defining the model architecture.
|
| 29 |
+
|
| 30 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
vocab_size (`int`, *optional*, defaults to 131072):
|
| 35 |
+
Vocabulary size of the Ministral model.
|
| 36 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 37 |
+
Dimension of the hidden representations.
|
| 38 |
+
intermediate_size (`int`, *optional*, defaults to 14336):
|
| 39 |
+
Dimension of the MLP representations.
|
| 40 |
+
num_hidden_layers (`int`, *optional*, defaults to 34):
|
| 41 |
+
Number of hidden layers in the Transformer decoder.
|
| 42 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 43 |
+
Number of attention heads for each attention layer.
|
| 44 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 45 |
+
Number of key_value heads for Grouped Query Attention.
|
| 46 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 47 |
+
The attention head dimension.
|
| 48 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 49 |
+
The non-linear activation function.
|
| 50 |
+
max_position_embeddings (`int`, *optional*, defaults to 262144):
|
| 51 |
+
The maximum sequence length.
|
| 52 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 53 |
+
The standard deviation of the truncated_normal_initializer.
|
| 54 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 55 |
+
The epsilon used by the rms normalization layers.
|
| 56 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 57 |
+
Whether or not the model should return the last key/values attentions.
|
| 58 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 59 |
+
Whether the model's input and output word embeddings should be tied.
|
| 60 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 61 |
+
The base period of the RoPE embeddings.
|
| 62 |
+
rope_parameters (`Dict`, *optional*):
|
| 63 |
+
Dictionary containing the scaling configuration for the RoPE embeddings.
|
| 64 |
+
Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
|
| 65 |
+
attention_bias (`bool`, defaults to `False`):
|
| 66 |
+
Whether to use a bias in the query, key, value and output projection layers.
|
| 67 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 68 |
+
The dropout ratio for the attention probabilities.
|
| 69 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 70 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers.
|
| 71 |
+
sliding_window (`int`, *optional*, defaults to None):
|
| 72 |
+
Sliding window attention size.
|
| 73 |
+
mask_token_id (`int`, *optional*, defaults to -1):
|
| 74 |
+
Token ID for masking in diffusion.
|
| 75 |
+
dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
|
| 76 |
+
Paradigm for diffusion ('bidirectional', 'autoregressive', 'block_diff').
|
| 77 |
+
block_size (`int`, *optional*, defaults to 32):
|
| 78 |
+
Block size for block diffusion paradigms.
|
| 79 |
+
dlm_loss_weight (`float`, *optional*):
|
| 80 |
+
Weight for diffusion LM loss.
|
| 81 |
+
ar_loss_weight (`float`, *optional*, defaults to 1.0):
|
| 82 |
+
Weight for autoregressive loss in block_diff paradigm. Use 10000 to only use AR loss.
|
| 83 |
+
dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
|
| 84 |
+
Whether to use varying mask ratio for each DP rank during sampling.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
model_type = "nemotron_labs_diffusion"
|
| 88 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 89 |
+
|
| 90 |
+
# Default tensor parallel plan for base model `Ministral`
|
| 91 |
+
base_model_tp_plan = {
|
| 92 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 93 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 94 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 95 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 96 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 97 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 98 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 99 |
+
}
|
| 100 |
+
base_model_pp_plan = {
|
| 101 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 102 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 103 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
vocab_size=131072,
|
| 109 |
+
hidden_size=4096,
|
| 110 |
+
intermediate_size=14336,
|
| 111 |
+
num_hidden_layers=34,
|
| 112 |
+
num_attention_heads=32,
|
| 113 |
+
num_key_value_heads=8,
|
| 114 |
+
head_dim=128,
|
| 115 |
+
hidden_act="silu",
|
| 116 |
+
max_position_embeddings=262144,
|
| 117 |
+
initializer_range=0.02,
|
| 118 |
+
rms_norm_eps=1e-05,
|
| 119 |
+
use_cache=True,
|
| 120 |
+
pad_token_id=None,
|
| 121 |
+
bos_token_id=1,
|
| 122 |
+
eos_token_id=2,
|
| 123 |
+
tie_word_embeddings=False,
|
| 124 |
+
rope_theta=1000000.0,
|
| 125 |
+
rope_parameters=None,
|
| 126 |
+
attention_bias=False,
|
| 127 |
+
attention_dropout=0.0,
|
| 128 |
+
mlp_bias=False,
|
| 129 |
+
sliding_window=None,
|
| 130 |
+
attn_implementation="sdpa",
|
| 131 |
+
mask_token_id=-1,
|
| 132 |
+
dlm_paradigm='bidirectional',
|
| 133 |
+
block_size=32,
|
| 134 |
+
dlm_loss_weight=None,
|
| 135 |
+
ar_loss_weight=1.0,
|
| 136 |
+
dp_varying_mask_ratio=False,
|
| 137 |
+
**kwargs,
|
| 138 |
+
):
|
| 139 |
+
self.vocab_size = vocab_size
|
| 140 |
+
self.max_position_embeddings = max_position_embeddings
|
| 141 |
+
self.hidden_size = hidden_size
|
| 142 |
+
self.intermediate_size = intermediate_size
|
| 143 |
+
self.num_hidden_layers = num_hidden_layers
|
| 144 |
+
self.num_attention_heads = num_attention_heads
|
| 145 |
+
|
| 146 |
+
# for backward compatibility
|
| 147 |
+
if num_key_value_heads is None:
|
| 148 |
+
num_key_value_heads = num_attention_heads
|
| 149 |
+
|
| 150 |
+
self.num_key_value_heads = num_key_value_heads
|
| 151 |
+
self.head_dim = head_dim
|
| 152 |
+
self.hidden_act = hidden_act
|
| 153 |
+
self.initializer_range = initializer_range
|
| 154 |
+
self.rms_norm_eps = rms_norm_eps
|
| 155 |
+
self.use_cache = use_cache
|
| 156 |
+
self.rope_parameters = rope_parameters
|
| 157 |
+
# `rope_theta` is read at the top level by transformers v4.55's yarn impl; mirror from rope_parameters when present.
|
| 158 |
+
self.rope_theta = (rope_parameters or {}).get("rope_theta", rope_theta)
|
| 159 |
+
# v4.55 reads rope params from `rope_scaling`; in v5.0 `rope_scaling` is a property alias for rope_parameters.
|
| 160 |
+
self.rope_scaling = rope_parameters
|
| 161 |
+
self.attention_bias = attention_bias
|
| 162 |
+
self.attention_dropout = attention_dropout
|
| 163 |
+
self.mlp_bias = mlp_bias
|
| 164 |
+
self.sliding_window = sliding_window
|
| 165 |
+
|
| 166 |
+
rope_config_validation(self)
|
| 167 |
+
|
| 168 |
+
self.attn_implementation = attn_implementation
|
| 169 |
+
|
| 170 |
+
self.mask_token_id = mask_token_id
|
| 171 |
+
self.dlm_paradigm = dlm_paradigm
|
| 172 |
+
self.block_size = block_size
|
| 173 |
+
self.dlm_loss_weight = dlm_loss_weight
|
| 174 |
+
self.ar_loss_weight = ar_loss_weight
|
| 175 |
+
self.dp_varying_mask_ratio = dp_varying_mask_ratio
|
| 176 |
+
super().__init__(
|
| 177 |
+
pad_token_id=pad_token_id,
|
| 178 |
+
bos_token_id=bos_token_id,
|
| 179 |
+
eos_token_id=eos_token_id,
|
| 180 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 181 |
+
**kwargs,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
__all__ = ["NemotronLabsDiffusionConfig"]
|
| 186 |
+
|
generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"transformers_version": "5.0.0",
|
| 6 |
+
"use_cache": false
|
| 7 |
+
}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73af2cd1c982f85bac01c7da43765deb3f2deced76eb93dbd2a6a968ff531349
|
| 3 |
+
size 16979144720
|
model_cards/bias.md
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:---------------------------------------------------------------------------------------------------|:---------------
|
| 3 |
+
Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | [None]
|
| 4 |
+
Measures taken to mitigate against unwanted bias: | [None]
|
model_cards/explainability.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------
|
| 3 |
+
Intended Task/Domain: | Text generation
|
| 4 |
+
Model Type: | Transformer
|
| 5 |
+
Intended Users: | Generative AI creators working with conversational AI models.
|
| 6 |
+
Output: | Text (Responds to posed question, Stateful - remembers previous answers)
|
| 7 |
+
Describe how the model works: | Text input is encoded into tokens and passed into a transformer-based language model, which returns a text response.
|
| 8 |
+
Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable
|
| 9 |
+
Technical Limitations & Mitigation: | The model cannot perform long-horizon reasoning and tool calling.
|
| 10 |
+
Verified to have met prescribed NVIDIA quality standards: | Yes
|
| 11 |
+
Performance Metrics: | Accuracy, Latency, Throughput
|
| 12 |
+
Potential Known Risks: | In some instances, the model may think too long and struggle to derive final answers. The model's output can generate all forms of text, including what may be considered toxic, offensive, or indecent.
|
| 13 |
+
Licensing: | nvidia-open-model-license.
|
model_cards/privacy.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:----------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------
|
| 3 |
+
Generatable or reverse engineerable personal data? | [No]
|
| 4 |
+
Personal data used to create this model? | [No]
|
| 5 |
+
Was consent obtained for any personal data used? | [Not Applicable]
|
| 6 |
+
How often is dataset reviewed? | [During dataset creation, model training, evaluation, and the prerelease phase.]
|
| 7 |
+
Was data from user interactions with the AI model (e.g. user input and prompts) used to train the model? | [Yes]
|
| 8 |
+
Is there provenance for all datasets used in training? | Yes
|
| 9 |
+
Does data labeling (annotation, metadata) comply with privacy laws? | Yes
|
| 10 |
+
Is data compliant with data subject requests for data correction or removal, if such a request was made? | Not Applicable.
|
| 11 |
+
Applicable Privacy Policy | https://www.nvidia.com/en-us/about-nvidia/privacy-policy/
|
model_cards/safety.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:---------------------------------------------------|:----------------------------------
|
| 3 |
+
Model Application Field(s): | [Media & Entertainment].
|
| 4 |
+
Describe the life critical impact (if present). | Not Applicable
|
| 5 |
+
Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to.
|
| 6 |
+
Use Case Restrictions: | Abide by nvidia-open-model-license.
|
modeling_ministral.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from transformers.utils.generic import check_model_inputs
|
| 8 |
+
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 11 |
+
from transformers.generation import GenerationMixin
|
| 12 |
+
# from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
|
| 13 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 14 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS
|
| 15 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 16 |
+
from transformers.modeling_layers import (
|
| 17 |
+
GenericForQuestionAnswering,
|
| 18 |
+
GenericForSequenceClassification,
|
| 19 |
+
GenericForTokenClassification,
|
| 20 |
+
GradientCheckpointingLayer,
|
| 21 |
+
)
|
| 22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 23 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 24 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 25 |
+
from transformers.processing_utils import Unpack
|
| 26 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 27 |
+
# from transformers.utils.generic import maybe_autocast
|
| 28 |
+
from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
|
| 29 |
+
|
| 30 |
+
#ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
|
| 31 |
+
|
| 32 |
+
def rotate_half(x):
|
| 33 |
+
"""Rotates half the hidden dims of the input."""
|
| 34 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 35 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 36 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 37 |
+
|
| 38 |
+
# @use_kernel_func_from_hub("rotary_pos_emb")
|
| 39 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 40 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
q (`torch.Tensor`): The query tensor.
|
| 44 |
+
k (`torch.Tensor`): The key tensor.
|
| 45 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 46 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 47 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 48 |
+
Deprecated and unused.
|
| 49 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 50 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 51 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 52 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 53 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 54 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 55 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 56 |
+
Returns:
|
| 57 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 58 |
+
"""
|
| 59 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 60 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 61 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 62 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 63 |
+
return q_embed, k_embed
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 67 |
+
"""
|
| 68 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 69 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 70 |
+
"""
|
| 71 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 72 |
+
if n_rep == 1:
|
| 73 |
+
return hidden_states
|
| 74 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 75 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def eager_attention_forward(
|
| 79 |
+
module: nn.Module,
|
| 80 |
+
query: torch.Tensor,
|
| 81 |
+
key: torch.Tensor,
|
| 82 |
+
value: torch.Tensor,
|
| 83 |
+
attention_mask: Optional[torch.Tensor],
|
| 84 |
+
scaling: float,
|
| 85 |
+
dropout: float = 0.0,
|
| 86 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 87 |
+
):
|
| 88 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 89 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 90 |
+
|
| 91 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 92 |
+
if attention_mask is not None:
|
| 93 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 94 |
+
attn_weights = attn_weights + causal_mask
|
| 95 |
+
|
| 96 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 97 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 98 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 99 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 100 |
+
|
| 101 |
+
return attn_output, attn_weights
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
|
| 105 |
+
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
|
| 106 |
+
return scaling.unsqueeze(-1)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# @use_kernelized_func(apply_rotary_pos_emb)
|
| 110 |
+
class Ministral3Attention(nn.Module):
|
| 111 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.config = config
|
| 116 |
+
self.layer_idx = layer_idx
|
| 117 |
+
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 118 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 119 |
+
self.scaling = self.head_dim**-0.5
|
| 120 |
+
self.attention_dropout = config.attention_dropout
|
| 121 |
+
self.is_causal = True
|
| 122 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
| 123 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 124 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 125 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
| 126 |
+
|
| 127 |
+
self.diffusion_lm = config.diffusion_lm
|
| 128 |
+
|
| 129 |
+
def forward(
|
| 130 |
+
self,
|
| 131 |
+
hidden_states: torch.Tensor,
|
| 132 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 133 |
+
attention_mask: Optional[torch.Tensor],
|
| 134 |
+
past_key_values: Optional[Cache] = None,
|
| 135 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 136 |
+
use_cache: Optional[bool] = False,
|
| 137 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 138 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 139 |
+
input_shape = hidden_states.shape[:-1]
|
| 140 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 141 |
+
|
| 142 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 143 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 144 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 145 |
+
|
| 146 |
+
cos, sin = position_embeddings
|
| 147 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 148 |
+
query_states = query_states * _get_llama_4_attn_scale(
|
| 149 |
+
cache_position,
|
| 150 |
+
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 151 |
+
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 152 |
+
).to(query_states.dtype)
|
| 153 |
+
|
| 154 |
+
if past_key_values is not None:
|
| 155 |
+
if use_cache:
|
| 156 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 157 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 158 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 159 |
+
else: ## if use_cache == False, do not update cache
|
| 160 |
+
old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
|
| 161 |
+
key_states = torch.cat([old_k, key_states], dim=-2)
|
| 162 |
+
value_states = torch.cat([old_v, value_states], dim=-2)
|
| 163 |
+
|
| 164 |
+
attention_interface: Callable = eager_attention_forward
|
| 165 |
+
if self.config._attn_implementation != "eager":
|
| 166 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 167 |
+
|
| 168 |
+
if self.diffusion_lm:
|
| 169 |
+
attn_output, attn_weights = attention_interface(
|
| 170 |
+
self,
|
| 171 |
+
query_states,
|
| 172 |
+
key_states,
|
| 173 |
+
value_states,
|
| 174 |
+
None,
|
| 175 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 176 |
+
scaling=self.scaling,
|
| 177 |
+
is_causal=False,
|
| 178 |
+
**kwargs,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
else:
|
| 182 |
+
attn_output, attn_weights = attention_interface(
|
| 183 |
+
self,
|
| 184 |
+
query_states,
|
| 185 |
+
key_states,
|
| 186 |
+
value_states,
|
| 187 |
+
attention_mask,
|
| 188 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 189 |
+
scaling=self.scaling,
|
| 190 |
+
sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
|
| 191 |
+
**kwargs,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 195 |
+
attn_output = self.o_proj(attn_output)
|
| 196 |
+
return attn_output, attn_weights
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class Ministral3MLP(nn.Module):
|
| 200 |
+
def __init__(self, config):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.config = config
|
| 203 |
+
self.hidden_size = config.hidden_size
|
| 204 |
+
self.intermediate_size = config.intermediate_size
|
| 205 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 206 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 207 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 208 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 209 |
+
|
| 210 |
+
def forward(self, x):
|
| 211 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 212 |
+
return down_proj
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 216 |
+
class Ministral3RMSNorm(nn.Module):
|
| 217 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 218 |
+
"""
|
| 219 |
+
Ministral3RMSNorm is equivalent to T5LayerNorm
|
| 220 |
+
"""
|
| 221 |
+
super().__init__()
|
| 222 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 223 |
+
self.variance_epsilon = eps
|
| 224 |
+
|
| 225 |
+
def forward(self, hidden_states):
|
| 226 |
+
input_dtype = hidden_states.dtype
|
| 227 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 228 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 229 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 230 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 231 |
+
|
| 232 |
+
def extra_repr(self):
|
| 233 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class Ministral3DecoderLayer(GradientCheckpointingLayer):
|
| 237 |
+
def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.hidden_size = config.hidden_size
|
| 240 |
+
|
| 241 |
+
if hasattr(config, 'attn_class'):
|
| 242 |
+
attn_class = config.attn_class
|
| 243 |
+
else:
|
| 244 |
+
attn_class = Ministral3Attention
|
| 245 |
+
|
| 246 |
+
self.self_attn = attn_class(config=config, layer_idx=layer_idx)
|
| 247 |
+
self.mlp = Ministral3MLP(config)
|
| 248 |
+
self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 249 |
+
self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 250 |
+
|
| 251 |
+
def forward(
|
| 252 |
+
self,
|
| 253 |
+
hidden_states: torch.Tensor,
|
| 254 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 255 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 256 |
+
past_key_values: Optional[Cache] = None,
|
| 257 |
+
use_cache: Optional[bool] = False,
|
| 258 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 259 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 260 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 261 |
+
) -> torch.Tensor:
|
| 262 |
+
residual = hidden_states
|
| 263 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 264 |
+
# Self Attention
|
| 265 |
+
hidden_states, _ = self.self_attn(
|
| 266 |
+
hidden_states=hidden_states,
|
| 267 |
+
attention_mask=attention_mask,
|
| 268 |
+
position_ids=position_ids,
|
| 269 |
+
past_key_values=past_key_values,
|
| 270 |
+
use_cache=use_cache,
|
| 271 |
+
cache_position=cache_position,
|
| 272 |
+
position_embeddings=position_embeddings,
|
| 273 |
+
**kwargs,
|
| 274 |
+
)
|
| 275 |
+
hidden_states = residual + hidden_states
|
| 276 |
+
|
| 277 |
+
# Fully Connected
|
| 278 |
+
residual = hidden_states
|
| 279 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 280 |
+
hidden_states = self.mlp(hidden_states)
|
| 281 |
+
hidden_states = residual + hidden_states
|
| 282 |
+
return hidden_states
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@auto_docstring
|
| 286 |
+
class Ministral3PreTrainedModel(PreTrainedModel):
|
| 287 |
+
config: NemotronLabsDiffusionConfig
|
| 288 |
+
base_model_prefix = "model"
|
| 289 |
+
supports_gradient_checkpointing = True
|
| 290 |
+
_no_split_modules = ["Ministral3DecoderLayer"]
|
| 291 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 292 |
+
_supports_flash_attn = True
|
| 293 |
+
_supports_sdpa = True
|
| 294 |
+
_supports_flex_attn = True
|
| 295 |
+
|
| 296 |
+
_can_compile_fullgraph = True
|
| 297 |
+
_supports_attention_backend = True
|
| 298 |
+
_can_record_outputs = {
|
| 299 |
+
"hidden_states": Ministral3DecoderLayer,
|
| 300 |
+
"attentions": Ministral3Attention,
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class Ministral3RotaryEmbedding(nn.Module):
|
| 305 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 306 |
+
|
| 307 |
+
def __init__(self, config: NemotronLabsDiffusionConfig, device=None):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 310 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 311 |
+
|
| 312 |
+
self.config = config
|
| 313 |
+
|
| 314 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 315 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 316 |
+
if self.rope_type != "default":
|
| 317 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 318 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 319 |
+
|
| 320 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 321 |
+
self.original_inv_freq = inv_freq
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@staticmethod
|
| 325 |
+
def compute_default_rope_parameters(
|
| 326 |
+
config: Optional[NemotronLabsDiffusionConfig] = None,
|
| 327 |
+
device: Optional["torch.device"] = None,
|
| 328 |
+
seq_len: Optional[int] = None,
|
| 329 |
+
) -> tuple["torch.Tensor", float]:
|
| 330 |
+
"""
|
| 331 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 332 |
+
Args:
|
| 333 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 334 |
+
The model configuration.
|
| 335 |
+
device (`torch.device`):
|
| 336 |
+
The device to use for initialization of the inverse frequencies.
|
| 337 |
+
seq_len (`int`, *optional*):
|
| 338 |
+
The current sequence length. Unused for this type of RoPE.
|
| 339 |
+
Returns:
|
| 340 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 341 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 342 |
+
"""
|
| 343 |
+
base = config.rope_parameters["rope_theta"]
|
| 344 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 345 |
+
|
| 346 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 347 |
+
|
| 348 |
+
# Compute the inverse frequencies
|
| 349 |
+
inv_freq = 1.0 / (
|
| 350 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 351 |
+
)
|
| 352 |
+
return inv_freq, attention_factor
|
| 353 |
+
|
| 354 |
+
@torch.no_grad()
|
| 355 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 356 |
+
def forward(self, x, position_ids):
|
| 357 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 358 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 359 |
+
|
| 360 |
+
# device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 361 |
+
# with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 362 |
+
|
| 363 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 364 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 365 |
+
cos = emb.cos() * self.attention_scaling
|
| 366 |
+
sin = emb.sin() * self.attention_scaling
|
| 367 |
+
|
| 368 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@auto_docstring
|
| 372 |
+
class Ministral3Model(Ministral3PreTrainedModel):
|
| 373 |
+
def __init__(self, config: NemotronLabsDiffusionConfig):
|
| 374 |
+
super().__init__(config)
|
| 375 |
+
self.padding_idx = config.pad_token_id
|
| 376 |
+
self.vocab_size = config.vocab_size
|
| 377 |
+
|
| 378 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 379 |
+
self.layers = nn.ModuleList(
|
| 380 |
+
[Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 381 |
+
)
|
| 382 |
+
self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 383 |
+
self.rotary_emb = Ministral3RotaryEmbedding(config=config)
|
| 384 |
+
self.gradient_checkpointing = False
|
| 385 |
+
|
| 386 |
+
# Initialize weights and apply final processing
|
| 387 |
+
self.post_init()
|
| 388 |
+
|
| 389 |
+
@check_model_inputs
|
| 390 |
+
@auto_docstring
|
| 391 |
+
def forward(
|
| 392 |
+
self,
|
| 393 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 394 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 395 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 396 |
+
past_key_values: Optional[Cache] = None,
|
| 397 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 398 |
+
use_cache: Optional[bool] = None,
|
| 399 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 400 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 401 |
+
) -> BaseModelOutputWithPast:
|
| 402 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 403 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 404 |
+
|
| 405 |
+
if inputs_embeds is None:
|
| 406 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 407 |
+
|
| 408 |
+
if use_cache and past_key_values is None:
|
| 409 |
+
# past_key_values = DynamicCache(config=self.config)
|
| 410 |
+
past_key_values = DynamicCache()
|
| 411 |
+
|
| 412 |
+
if cache_position is None:
|
| 413 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 414 |
+
cache_position = torch.arange(
|
| 415 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
if position_ids is None:
|
| 419 |
+
position_ids = cache_position.unsqueeze(0)
|
| 420 |
+
|
| 421 |
+
if kwargs.get("use_causal_mask", False):
|
| 422 |
+
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
| 423 |
+
causal_mask = mask_function(
|
| 424 |
+
config=self.config,
|
| 425 |
+
input_embeds=inputs_embeds,
|
| 426 |
+
attention_mask=attention_mask,
|
| 427 |
+
cache_position=cache_position,
|
| 428 |
+
past_key_values=past_key_values,
|
| 429 |
+
position_ids=position_ids,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
else:
|
| 433 |
+
causal_mask = None
|
| 434 |
+
|
| 435 |
+
hidden_states = inputs_embeds
|
| 436 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 437 |
+
|
| 438 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 439 |
+
hidden_states = decoder_layer(
|
| 440 |
+
hidden_states,
|
| 441 |
+
attention_mask=causal_mask,
|
| 442 |
+
position_ids=position_ids,
|
| 443 |
+
past_key_values=past_key_values,
|
| 444 |
+
use_cache=use_cache,
|
| 445 |
+
cache_position=cache_position,
|
| 446 |
+
position_embeddings=position_embeddings,
|
| 447 |
+
**kwargs,
|
| 448 |
+
)
|
| 449 |
+
hidden_states = self.norm(hidden_states)
|
| 450 |
+
return BaseModelOutputWithPast(
|
| 451 |
+
last_hidden_state=hidden_states,
|
| 452 |
+
past_key_values=past_key_values if use_cache else None,
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
__all__ = [
|
| 457 |
+
"Ministral3Model",
|
| 458 |
+
"Ministral3PreTrainedModel",
|
| 459 |
+
]
|
modeling_nemotron_labs_diffusion.py
ADDED
|
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
|
| 10 |
+
from transformers.utils import ModelOutput
|
| 11 |
+
|
| 12 |
+
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 13 |
+
|
| 14 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 15 |
+
|
| 16 |
+
from transformers.processing_utils import Unpack
|
| 17 |
+
|
| 18 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 19 |
+
|
| 20 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 21 |
+
|
| 22 |
+
from transformers.generation import GenerationMixin
|
| 23 |
+
|
| 24 |
+
import math
|
| 25 |
+
|
| 26 |
+
from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
|
| 27 |
+
from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
|
| 28 |
+
|
| 29 |
+
__all__ = ["NemotronLabsDiffusionModel", "NemotronLabsDiffusionFlexAttention"]
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class NemotronLabsDiffusionOutputWithPast(ModelOutput):
|
| 33 |
+
loss: torch.FloatTensor | None = None
|
| 34 |
+
logits: torch.FloatTensor | None = None
|
| 35 |
+
causal_logits: torch.FloatTensor | None = None
|
| 36 |
+
past_key_values: Cache | None = None
|
| 37 |
+
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
| 38 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
|
| 42 |
+
def fused_flex_attention(q, k, v, block_mask=None):
|
| 43 |
+
return flex_attention(q, k, v, block_mask=block_mask)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class NemotronLabsDiffusionFlexAttention(Ministral3Attention):
|
| 47 |
+
def __init__(self, *args, **kwargs):
|
| 48 |
+
super().__init__(*args, **kwargs)
|
| 49 |
+
|
| 50 |
+
self.block_size = self.config.block_size
|
| 51 |
+
self.block_diff_mask = None
|
| 52 |
+
|
| 53 |
+
import torch._dynamo.config as dcfg
|
| 54 |
+
dcfg.cache_size_limit = 512
|
| 55 |
+
|
| 56 |
+
def compute_block_mask(self, mode, q_len, block_size=None):
|
| 57 |
+
|
| 58 |
+
def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 59 |
+
x0_flag_q = (q_idx >= n)
|
| 60 |
+
x0_flag_kv = (kv_idx >= n)
|
| 61 |
+
|
| 62 |
+
# Compute block indices
|
| 63 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 64 |
+
(q_idx - n) // block_size,
|
| 65 |
+
q_idx // block_size)
|
| 66 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 67 |
+
(kv_idx - n) // block_size,
|
| 68 |
+
kv_idx // block_size)
|
| 69 |
+
|
| 70 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 71 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
|
| 72 |
+
|
| 73 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 74 |
+
offset_block_causal = (
|
| 75 |
+
(block_q > block_kv)
|
| 76 |
+
& (x0_flag_kv == 1)
|
| 77 |
+
& (x0_flag_q == 0)
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# **3. Fully Causal Mask (M_BC) **
|
| 81 |
+
fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 82 |
+
|
| 83 |
+
# **4. Combine Masks **
|
| 84 |
+
return block_diagonal | offset_block_causal | fully_causal
|
| 85 |
+
|
| 86 |
+
attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
|
| 87 |
+
|
| 88 |
+
block_mask = create_block_mask(
|
| 89 |
+
attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return block_mask
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def forward(
|
| 96 |
+
self,
|
| 97 |
+
hidden_states: torch.Tensor,
|
| 98 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 99 |
+
attention_mask: Optional[torch.Tensor],
|
| 100 |
+
past_key_values: Optional[Cache] = None,
|
| 101 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 102 |
+
is_training: bool = True,
|
| 103 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 104 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 105 |
+
bsz, q_len, _ = hidden_states.size()
|
| 106 |
+
input_shape = hidden_states.shape[:-1]
|
| 107 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 108 |
+
|
| 109 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 110 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 111 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 112 |
+
|
| 113 |
+
cos, sin = position_embeddings
|
| 114 |
+
|
| 115 |
+
if is_training:
|
| 116 |
+
# Split query and key states in half along sequence length dimension
|
| 117 |
+
q1, q2 = query_states.chunk(2, dim=2)
|
| 118 |
+
k1, k2 = key_states.chunk(2, dim=2)
|
| 119 |
+
|
| 120 |
+
# Apply RoPE independently to each half
|
| 121 |
+
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
| 122 |
+
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
| 123 |
+
|
| 124 |
+
# Recombine the halves
|
| 125 |
+
query_states = torch.cat([q1, q2], dim=2)
|
| 126 |
+
key_states = torch.cat([k1, k2], dim=2)
|
| 127 |
+
else:
|
| 128 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 129 |
+
|
| 130 |
+
query_states = query_states * _get_llama_4_attn_scale(
|
| 131 |
+
cache_position,
|
| 132 |
+
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 133 |
+
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 134 |
+
).to(query_states.dtype)
|
| 135 |
+
|
| 136 |
+
if past_key_values is not None:
|
| 137 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 138 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 139 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 140 |
+
|
| 141 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 142 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 143 |
+
|
| 144 |
+
if self.block_diff_mask is None or q_len != self.block_diff_mask.shape[-2]:
|
| 145 |
+
block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
|
| 146 |
+
else:
|
| 147 |
+
block_mask = self.block_diff_mask
|
| 148 |
+
|
| 149 |
+
attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 150 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 151 |
+
|
| 152 |
+
attn_output = self.o_proj(attn_output)
|
| 153 |
+
|
| 154 |
+
return attn_output, None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class NemotronLabsDiffusionModel(Ministral3PreTrainedModel, GenerationMixin):
|
| 158 |
+
"""
|
| 159 |
+
A single model with:
|
| 160 |
+
- a bidirectional encoder + diffusion‐LM head over A
|
| 161 |
+
- a causal decoder + LM head over B, conditioned on F_A
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
def __init__(self, config: NemotronLabsDiffusionConfig):
|
| 165 |
+
super().__init__(config)
|
| 166 |
+
|
| 167 |
+
self.mask_token_id = config.mask_token_id
|
| 168 |
+
|
| 169 |
+
diffusion_config = copy.deepcopy(config)
|
| 170 |
+
diffusion_config.diffusion_lm = True
|
| 171 |
+
|
| 172 |
+
if config.dlm_paradigm == 'block_diff':
|
| 173 |
+
diffusion_config.attn_class = NemotronLabsDiffusionFlexAttention
|
| 174 |
+
elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
|
| 175 |
+
diffusion_config.attn_class = Ministral3Attention
|
| 176 |
+
if config.dlm_paradigm == 'autoregressive':
|
| 177 |
+
diffusion_config.diffusion_lm = False
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
|
| 180 |
+
|
| 181 |
+
self.encoder = Ministral3Model(diffusion_config)
|
| 182 |
+
self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 183 |
+
self.vocab_size = config.vocab_size
|
| 184 |
+
|
| 185 |
+
self.post_init()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_input_embeddings(self):
|
| 189 |
+
return self.encoder.embed_tokens
|
| 190 |
+
|
| 191 |
+
def set_input_embeddings(self, value):
|
| 192 |
+
self.encoder.embed_tokens = value
|
| 193 |
+
|
| 194 |
+
def get_output_embeddings(self):
|
| 195 |
+
return self.diffusion_head
|
| 196 |
+
|
| 197 |
+
def set_output_embeddings(self, new_embeddings):
|
| 198 |
+
self.diffusion_head = new_embeddings
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
|
| 202 |
+
b, l = input_ids.shape
|
| 203 |
+
device = input_ids.device
|
| 204 |
+
|
| 205 |
+
if self.config.dp_varying_mask_ratio:
|
| 206 |
+
# Enable different random seeds for each DP rank during sampling
|
| 207 |
+
import torch.distributed as dist
|
| 208 |
+
dp_rank = 0
|
| 209 |
+
if dist.is_initialized():
|
| 210 |
+
try:
|
| 211 |
+
dp_rank = dist.get_rank()
|
| 212 |
+
except Exception:
|
| 213 |
+
dp_rank = 0
|
| 214 |
+
# Use a local generator to avoid affecting global RNG state
|
| 215 |
+
generator = torch.Generator(device=device)
|
| 216 |
+
generator.manual_seed(torch.seed() + dp_rank)
|
| 217 |
+
else:
|
| 218 |
+
generator = None
|
| 219 |
+
|
| 220 |
+
t = torch.rand(b, device=device, generator=generator)
|
| 221 |
+
|
| 222 |
+
p_mask = (1 - eps) * t + eps # shape: (b,)
|
| 223 |
+
p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
|
| 224 |
+
|
| 225 |
+
masked_indices = torch.rand((b, l), device=device) < p_mask
|
| 226 |
+
|
| 227 |
+
if loss_mask is not None:
|
| 228 |
+
masked_indices[loss_mask == 0] = 0
|
| 229 |
+
|
| 230 |
+
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 231 |
+
|
| 232 |
+
return noisy_batch, masked_indices, p_mask
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def forward(
|
| 236 |
+
self,
|
| 237 |
+
input_ids: torch.LongTensor,
|
| 238 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 239 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 240 |
+
labels: Optional[torch.LongTensor] = None,
|
| 241 |
+
split_len: Optional[int] = None,
|
| 242 |
+
past_key_values: Optional[Cache] = None,
|
| 243 |
+
block_size: Optional[int] = None,
|
| 244 |
+
eps: float = 1e-3,
|
| 245 |
+
is_teacher: bool = False,
|
| 246 |
+
masked_indices: Optional[torch.Tensor] = None,
|
| 247 |
+
p_mask: Optional[torch.Tensor] = None,
|
| 248 |
+
teacher_logits: Optional[torch.Tensor] = None,
|
| 249 |
+
masked_indices_teacher: Optional[torch.Tensor] = None,
|
| 250 |
+
loss_mask: Optional[torch.Tensor] = None,
|
| 251 |
+
ce_loss_weight: float = 1.0,
|
| 252 |
+
output_last_hidden_states_only: bool = False,
|
| 253 |
+
skip_loss: bool = False,
|
| 254 |
+
**kwargs,
|
| 255 |
+
) -> CausalLMOutputWithPast:
|
| 256 |
+
|
| 257 |
+
batch_size, seq_len = input_ids.shape
|
| 258 |
+
|
| 259 |
+
if self.config.dlm_paradigm == 'block_diff':
|
| 260 |
+
if labels is not None and block_size is None:
|
| 261 |
+
block_size = self.config.block_size
|
| 262 |
+
elif self.config.dlm_paradigm not in ('bidirectional', 'autoregressive'):
|
| 263 |
+
raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
|
| 264 |
+
|
| 265 |
+
if labels is not None and self.config.dlm_paradigm != 'autoregressive':
|
| 266 |
+
if masked_indices is not None:
|
| 267 |
+
# assert p_mask is not None
|
| 268 |
+
|
| 269 |
+
if loss_mask is not None:
|
| 270 |
+
masked_indices[loss_mask == 0] = 0
|
| 271 |
+
|
| 272 |
+
noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 273 |
+
|
| 274 |
+
else:
|
| 275 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
|
| 276 |
+
|
| 277 |
+
else:
|
| 278 |
+
noisy_inputs = input_ids
|
| 279 |
+
masked_indices = None
|
| 280 |
+
p_mask = None
|
| 281 |
+
|
| 282 |
+
input_ids_len = noisy_inputs.shape[1]
|
| 283 |
+
if labels is not None and self.config.dlm_paradigm == 'block_diff':
|
| 284 |
+
if position_ids is None:
|
| 285 |
+
position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
|
| 286 |
+
noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
|
| 287 |
+
|
| 288 |
+
enc_out = self.encoder(
|
| 289 |
+
past_key_values=past_key_values,
|
| 290 |
+
input_ids=noisy_inputs,
|
| 291 |
+
attention_mask=attention_mask,
|
| 292 |
+
position_ids=position_ids,
|
| 293 |
+
is_training=(labels is not None),
|
| 294 |
+
**kwargs,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if output_last_hidden_states_only:
|
| 298 |
+
return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
|
| 299 |
+
|
| 300 |
+
logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
|
| 301 |
+
causal_logits = None
|
| 302 |
+
|
| 303 |
+
if labels is not None and self.config.dlm_paradigm == 'block_diff':
|
| 304 |
+
causal_logits = logits[:, input_ids_len:]
|
| 305 |
+
logits = logits[:, :input_ids_len]
|
| 306 |
+
|
| 307 |
+
loss = None
|
| 308 |
+
if labels is not None and not skip_loss:
|
| 309 |
+
if self.config.dlm_paradigm == 'autoregressive':
|
| 310 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 311 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 312 |
+
|
| 313 |
+
if loss_mask is None:
|
| 314 |
+
loss_fct = CrossEntropyLoss()
|
| 315 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 316 |
+
shift_labels = shift_labels.view(-1)
|
| 317 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 318 |
+
|
| 319 |
+
else:
|
| 320 |
+
loss_mask = loss_mask[..., 1:].contiguous()
|
| 321 |
+
|
| 322 |
+
loss_fct = CrossEntropyLoss(reduction='none')
|
| 323 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 324 |
+
shift_labels = shift_labels.view(-1)
|
| 325 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 326 |
+
|
| 327 |
+
token_losses = loss_fct(shift_logits, shift_labels)
|
| 328 |
+
|
| 329 |
+
flat_loss_mask = loss_mask.reshape(-1)
|
| 330 |
+
loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
|
| 331 |
+
|
| 332 |
+
else:
|
| 333 |
+
# LLaDA-style diffusion loss on masked positions.
|
| 334 |
+
# Token-wise cross entropy loss on masked positions.
|
| 335 |
+
token_loss = torch.nn.functional.cross_entropy(
|
| 336 |
+
logits[masked_indices],
|
| 337 |
+
labels[masked_indices],
|
| 338 |
+
reduction='none'
|
| 339 |
+
) / p_mask[masked_indices]
|
| 340 |
+
|
| 341 |
+
num_mask_tokens = masked_indices.sum()
|
| 342 |
+
|
| 343 |
+
# global_loss_avg=True: loss is reduced externally by global token count.
|
| 344 |
+
loss = token_loss.sum()
|
| 345 |
+
|
| 346 |
+
if self.config.dlm_loss_weight is not None:
|
| 347 |
+
loss = self.config.dlm_loss_weight * loss
|
| 348 |
+
|
| 349 |
+
if self.config.dlm_paradigm == 'block_diff':
|
| 350 |
+
# AR-side loss for block-diffusion paradigm.
|
| 351 |
+
causal_logits = causal_logits[..., :-1, :].contiguous()
|
| 352 |
+
causal_logits = causal_logits.view(-1, causal_logits.size(-1))
|
| 353 |
+
causal_labels = labels[..., 1:].contiguous().view(-1)
|
| 354 |
+
|
| 355 |
+
loss_fct = CrossEntropyLoss(reduction='sum')
|
| 356 |
+
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 357 |
+
|
| 358 |
+
self.loss_diffusion = loss.detach().item() / num_mask_tokens
|
| 359 |
+
self.loss_ar = ar_loss.detach().item() / seq_len
|
| 360 |
+
|
| 361 |
+
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 362 |
+
|
| 363 |
+
# global_loss_avg=True: return (sum_loss, token_count) for external mean.
|
| 364 |
+
if self.config.dlm_paradigm == 'block_diff':
|
| 365 |
+
loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
|
| 366 |
+
else:
|
| 367 |
+
loss = (loss, num_mask_tokens)
|
| 368 |
+
|
| 369 |
+
return NemotronLabsDiffusionOutputWithPast(
|
| 370 |
+
loss=loss if not is_teacher else logits,
|
| 371 |
+
logits=logits,
|
| 372 |
+
causal_logits=causal_logits,
|
| 373 |
+
past_key_values=enc_out.past_key_values,
|
| 374 |
+
hidden_states=None,
|
| 375 |
+
attentions=None,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@torch.no_grad()
|
| 380 |
+
def generate(
|
| 381 |
+
self,
|
| 382 |
+
prompt_ids: torch.Tensor,
|
| 383 |
+
max_new_tokens: int,
|
| 384 |
+
block_length: int,
|
| 385 |
+
threshold: Optional[float] = None,
|
| 386 |
+
causal_context: bool = True,
|
| 387 |
+
temperature: float = 0.0,
|
| 388 |
+
eos_token_id: Optional[int] = None,
|
| 389 |
+
max_thinking_tokens: Optional[int] = None,
|
| 390 |
+
end_think_token_id: Optional[int] = None,
|
| 391 |
+
):
|
| 392 |
+
"""Block-wise diffusion decoding with prefix-cached KV (LLaDA-style).
|
| 393 |
+
|
| 394 |
+
Each block: append `block_length` mask tokens, then iteratively unmask
|
| 395 |
+
by confidence top-k (with optional threshold). When `causal_context`,
|
| 396 |
+
the KV cache and the next-block seed are produced via a causal forward
|
| 397 |
+
between blocks (flipping `self_attn.diffusion_lm`), matching the AR
|
| 398 |
+
objective at block boundaries.
|
| 399 |
+
|
| 400 |
+
Returns (output_ids, nfe) — output_ids includes the prompt.
|
| 401 |
+
"""
|
| 402 |
+
if eos_token_id is None:
|
| 403 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 404 |
+
mask_id = self.mask_token_id
|
| 405 |
+
|
| 406 |
+
x_accum = prompt_ids.clone()
|
| 407 |
+
B = prompt_ids.shape[0]
|
| 408 |
+
|
| 409 |
+
assert max_new_tokens % block_length == 0
|
| 410 |
+
num_blocks = max_new_tokens // block_length
|
| 411 |
+
# one denoising step per generated token (matches legacy chat_utils call)
|
| 412 |
+
steps_per_block = block_length
|
| 413 |
+
|
| 414 |
+
nfe = 0
|
| 415 |
+
|
| 416 |
+
def _set_diffusion_lm(val: bool):
|
| 417 |
+
for layer in self.encoder.layers:
|
| 418 |
+
if hasattr(layer.self_attn, "diffusion_lm"):
|
| 419 |
+
layer.self_attn.diffusion_lm = val
|
| 420 |
+
|
| 421 |
+
# Initial causal prefill produces the KV cache and the next-block seed.
|
| 422 |
+
if causal_context:
|
| 423 |
+
_set_diffusion_lm(False)
|
| 424 |
+
output = self(prompt_ids, use_cache=True, use_causal_mask=causal_context)
|
| 425 |
+
past_key_values = output.past_key_values
|
| 426 |
+
if causal_context:
|
| 427 |
+
_set_diffusion_lm(True)
|
| 428 |
+
|
| 429 |
+
next_token = None
|
| 430 |
+
if causal_context:
|
| 431 |
+
last_logit = output.logits[:, -1, :]
|
| 432 |
+
if temperature > 0:
|
| 433 |
+
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 434 |
+
else:
|
| 435 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 436 |
+
|
| 437 |
+
for num_block in range(num_blocks):
|
| 438 |
+
mask_block = torch.full(
|
| 439 |
+
(B, block_length), mask_id, dtype=prompt_ids.dtype, device=prompt_ids.device,
|
| 440 |
+
)
|
| 441 |
+
if causal_context:
|
| 442 |
+
mask_block[:, 0] = next_token[:, 0]
|
| 443 |
+
|
| 444 |
+
x_accum = torch.cat([x_accum, mask_block], dim=1)
|
| 445 |
+
block_start = prompt_ids.size(1) + num_block * block_length
|
| 446 |
+
block_slice = slice(block_start, block_start + block_length)
|
| 447 |
+
|
| 448 |
+
# Thinking-budget enforcement: if we've passed max_thinking_tokens
|
| 449 |
+
# without an end-think marker, inject one into this block.
|
| 450 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 451 |
+
tokens_before = num_block * block_length
|
| 452 |
+
tokens_after = tokens_before + block_length
|
| 453 |
+
if tokens_after > max_thinking_tokens:
|
| 454 |
+
gen_so_far = x_accum[:, prompt_ids.size(1):block_start]
|
| 455 |
+
has_end_think = (
|
| 456 |
+
(gen_so_far == end_think_token_id).any(dim=1)
|
| 457 |
+
if gen_so_far.size(1) > 0
|
| 458 |
+
else torch.zeros(B, dtype=torch.bool, device=prompt_ids.device)
|
| 459 |
+
)
|
| 460 |
+
if not has_end_think.all():
|
| 461 |
+
offset = max(0, max_thinking_tokens - tokens_before)
|
| 462 |
+
inject_pos = block_start + offset
|
| 463 |
+
for b in range(B):
|
| 464 |
+
if not has_end_think[b]:
|
| 465 |
+
x_accum[b, inject_pos] = end_think_token_id
|
| 466 |
+
|
| 467 |
+
mask_block_idx0 = x_accum[:, block_slice] == mask_id
|
| 468 |
+
num_transfer_tokens = _get_num_transfer_tokens(mask_block_idx0, steps_per_block)
|
| 469 |
+
|
| 470 |
+
# Denoise the current block by repeated confidence-based unmasking.
|
| 471 |
+
for i in range(steps_per_block):
|
| 472 |
+
mask_block_idx = x_accum[:, block_slice] == mask_id
|
| 473 |
+
if mask_block_idx.sum() == 0:
|
| 474 |
+
break
|
| 475 |
+
|
| 476 |
+
nfe += 1
|
| 477 |
+
logits_block = self(
|
| 478 |
+
x_accum[:, block_slice],
|
| 479 |
+
past_key_values=past_key_values,
|
| 480 |
+
use_cache=False,
|
| 481 |
+
).logits
|
| 482 |
+
|
| 483 |
+
x0, transfer_idx = _get_transfer_index(
|
| 484 |
+
logits_block, temperature, mask_block_idx, x_accum[:, block_slice],
|
| 485 |
+
num_transfer_tokens=num_transfer_tokens[:, i], threshold=threshold,
|
| 486 |
+
)
|
| 487 |
+
cur = x_accum[:, block_slice].clone()
|
| 488 |
+
cur[transfer_idx] = x0[transfer_idx]
|
| 489 |
+
x_accum[:, block_slice] = cur
|
| 490 |
+
|
| 491 |
+
if eos_token_id is not None:
|
| 492 |
+
block_tokens = x_accum[:, block_slice]
|
| 493 |
+
eos_mask = block_tokens == eos_token_id
|
| 494 |
+
if eos_mask.any(dim=1).any():
|
| 495 |
+
after_eos = eos_mask.cumsum(dim=1).bool()
|
| 496 |
+
mask_before = (block_tokens == mask_id) & ~after_eos
|
| 497 |
+
if (eos_mask.any(dim=1) & ~mask_before.any(dim=1)).any():
|
| 498 |
+
break
|
| 499 |
+
|
| 500 |
+
# Post-block: causal forward over the block to update the KV cache
|
| 501 |
+
# and (when causal_context) sample the seed for the next block.
|
| 502 |
+
if causal_context:
|
| 503 |
+
_set_diffusion_lm(False)
|
| 504 |
+
output = self(
|
| 505 |
+
x_accum[:, block_slice],
|
| 506 |
+
past_key_values=past_key_values,
|
| 507 |
+
use_cache=True,
|
| 508 |
+
use_causal_mask=causal_context,
|
| 509 |
+
)
|
| 510 |
+
past_key_values = output.past_key_values
|
| 511 |
+
nfe += 1
|
| 512 |
+
|
| 513 |
+
if causal_context:
|
| 514 |
+
_set_diffusion_lm(True)
|
| 515 |
+
last_logit = output.logits[:, -1, :]
|
| 516 |
+
if temperature > 0:
|
| 517 |
+
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 518 |
+
else:
|
| 519 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 520 |
+
|
| 521 |
+
if eos_token_id is not None:
|
| 522 |
+
gen_so_far = x_accum[:, prompt_ids.size(1):]
|
| 523 |
+
is_eos = gen_so_far == eos_token_id
|
| 524 |
+
if is_eos.any(dim=1).all():
|
| 525 |
+
first_eos = is_eos.to(torch.int64).argmax(dim=1)
|
| 526 |
+
max_eos = first_eos.max().item()
|
| 527 |
+
return x_accum[:, : prompt_ids.size(1) + max_eos + 1], nfe
|
| 528 |
+
|
| 529 |
+
return x_accum, nfe
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@torch.no_grad()
|
| 534 |
+
def ar_generate(
|
| 535 |
+
self,
|
| 536 |
+
prompt_ids: torch.Tensor,
|
| 537 |
+
max_new_tokens: int = 128,
|
| 538 |
+
temperature: float = 0.0,
|
| 539 |
+
eos_token_id: Optional[int] = None,
|
| 540 |
+
max_thinking_tokens: Optional[int] = None,
|
| 541 |
+
end_think_token_id: Optional[int] = None,
|
| 542 |
+
) -> tuple:
|
| 543 |
+
"""Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
|
| 544 |
+
|
| 545 |
+
Bypasses NemotronLabsDiffusionModel.forward() to avoid diffusion-specific
|
| 546 |
+
code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
|
| 547 |
+
position_ids, and use_cache so the KV cache and causal masking behave
|
| 548 |
+
identically to MistralForCausalLM / vLLM.
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
(output_ids, nfe) where output_ids includes the prompt.
|
| 552 |
+
"""
|
| 553 |
+
for layer in self.encoder.layers:
|
| 554 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 555 |
+
layer.self_attn.diffusion_lm = False
|
| 556 |
+
|
| 557 |
+
if eos_token_id is None:
|
| 558 |
+
eos_token_id = getattr(self.config, 'eos_token_id', None)
|
| 559 |
+
|
| 560 |
+
device = prompt_ids.device
|
| 561 |
+
batch_size, prompt_len = prompt_ids.shape
|
| 562 |
+
|
| 563 |
+
past_key_values = DynamicCache()
|
| 564 |
+
cache_position = torch.arange(prompt_len, device=device)
|
| 565 |
+
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
| 566 |
+
|
| 567 |
+
enc_out = self.encoder(
|
| 568 |
+
input_ids=prompt_ids,
|
| 569 |
+
position_ids=position_ids,
|
| 570 |
+
past_key_values=past_key_values,
|
| 571 |
+
use_cache=True,
|
| 572 |
+
cache_position=cache_position,
|
| 573 |
+
)
|
| 574 |
+
past_key_values = enc_out.past_key_values
|
| 575 |
+
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 576 |
+
|
| 577 |
+
generated_tokens = []
|
| 578 |
+
nfe = 0
|
| 579 |
+
|
| 580 |
+
for step in range(max_new_tokens):
|
| 581 |
+
nfe += 1
|
| 582 |
+
|
| 583 |
+
if temperature > 0:
|
| 584 |
+
probs = torch.softmax(next_logit / temperature, dim=-1)
|
| 585 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 586 |
+
else:
|
| 587 |
+
next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
|
| 588 |
+
|
| 589 |
+
# ---- thinking budget enforcement ----
|
| 590 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 591 |
+
if step >= max_thinking_tokens:
|
| 592 |
+
if generated_tokens:
|
| 593 |
+
gen_tensor = torch.cat(generated_tokens, dim=1)
|
| 594 |
+
has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
|
| 595 |
+
else:
|
| 596 |
+
has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 597 |
+
for b in range(batch_size):
|
| 598 |
+
if not has_end_think[b]:
|
| 599 |
+
next_token[b] = end_think_token_id
|
| 600 |
+
|
| 601 |
+
generated_tokens.append(next_token)
|
| 602 |
+
|
| 603 |
+
if eos_token_id is not None and (next_token == eos_token_id).all():
|
| 604 |
+
break
|
| 605 |
+
|
| 606 |
+
if step < max_new_tokens - 1:
|
| 607 |
+
cur_pos = prompt_len + step
|
| 608 |
+
step_cache_pos = torch.tensor([cur_pos], device=device)
|
| 609 |
+
step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
|
| 610 |
+
|
| 611 |
+
enc_out = self.encoder(
|
| 612 |
+
input_ids=next_token,
|
| 613 |
+
position_ids=step_pos_ids,
|
| 614 |
+
past_key_values=past_key_values,
|
| 615 |
+
use_cache=True,
|
| 616 |
+
cache_position=step_cache_pos,
|
| 617 |
+
)
|
| 618 |
+
past_key_values = enc_out.past_key_values
|
| 619 |
+
next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 620 |
+
|
| 621 |
+
all_generated = torch.cat(generated_tokens, dim=1)
|
| 622 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 623 |
+
return output_ids, nfe
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
@torch.no_grad()
|
| 627 |
+
def linear_spec_generate(
|
| 628 |
+
self,
|
| 629 |
+
prompt_ids: torch.Tensor,
|
| 630 |
+
max_new_tokens: int = 128,
|
| 631 |
+
block_length: int = 32,
|
| 632 |
+
temperature: float = 0.0,
|
| 633 |
+
mask_token_id: Optional[int] = None,
|
| 634 |
+
eos_token_id: Optional[int] = None,
|
| 635 |
+
max_thinking_tokens: Optional[int] = None,
|
| 636 |
+
end_think_token_id: Optional[int] = None,
|
| 637 |
+
threshold: float = 0.0,
|
| 638 |
+
):
|
| 639 |
+
"""Linear speculative decoding: diffusion draft + AR verify.
|
| 640 |
+
|
| 641 |
+
Each iteration: (1) draft the next block under bidirectional attention,
|
| 642 |
+
(2) verify the drafted block under causal attention, accept the longest
|
| 643 |
+
prefix where draft matches AR + one bonus token, advance the KV cache.
|
| 644 |
+
|
| 645 |
+
LoRA-aware: when a PEFT adapter is attached to the model (e.g.
|
| 646 |
+
``linear_spec_lora``), it is toggled ON for the bidirectional draft
|
| 647 |
+
phase and OFF for the causal prefill / verify phases — so the adapter
|
| 648 |
+
only specializes the diffusion-mode forward and AR semantics are
|
| 649 |
+
preserved. With no adapter loaded the calls are no-ops.
|
| 650 |
+
|
| 651 |
+
Returns ``(output_ids, nfe)`` — ``output_ids`` includes the prompt.
|
| 652 |
+
"""
|
| 653 |
+
if prompt_ids.shape[0] != 1:
|
| 654 |
+
raise ValueError("Linear speculative decoding requires batch_size == 1")
|
| 655 |
+
|
| 656 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 657 |
+
if eos_token_id is None:
|
| 658 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 659 |
+
|
| 660 |
+
device = prompt_ids.device
|
| 661 |
+
|
| 662 |
+
def _set_diffusion_lm(val: bool):
|
| 663 |
+
for layer in self.encoder.layers:
|
| 664 |
+
if hasattr(layer.self_attn, "diffusion_lm"):
|
| 665 |
+
layer.self_attn.diffusion_lm = val
|
| 666 |
+
|
| 667 |
+
def _toggle_adapters(enable: bool):
|
| 668 |
+
# No-op when no PEFT/LoRA modules are attached.
|
| 669 |
+
for module in self.modules():
|
| 670 |
+
if hasattr(module, "_disable_adapters"):
|
| 671 |
+
module._disable_adapters = not enable
|
| 672 |
+
|
| 673 |
+
# Prefill (causal, LoRA OFF).
|
| 674 |
+
_set_diffusion_lm(False)
|
| 675 |
+
_toggle_adapters(False)
|
| 676 |
+
enc_out = self.encoder(
|
| 677 |
+
input_ids=prompt_ids,
|
| 678 |
+
past_key_values=DynamicCache(),
|
| 679 |
+
use_cache=True,
|
| 680 |
+
use_causal_mask=True,
|
| 681 |
+
)
|
| 682 |
+
past_key_values = enc_out.past_key_values
|
| 683 |
+
last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
|
| 684 |
+
nfe = 1
|
| 685 |
+
|
| 686 |
+
if temperature > 0:
|
| 687 |
+
next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
|
| 688 |
+
else:
|
| 689 |
+
next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
|
| 690 |
+
|
| 691 |
+
if eos_token_id is not None and next_token.item() == eos_token_id:
|
| 692 |
+
return torch.cat([prompt_ids, next_token], dim=1), nfe
|
| 693 |
+
|
| 694 |
+
generated = [next_token]
|
| 695 |
+
total_gen = 1
|
| 696 |
+
|
| 697 |
+
while total_gen < max_new_tokens:
|
| 698 |
+
cache_len = past_key_values.get_seq_length()
|
| 699 |
+
|
| 700 |
+
block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
|
| 701 |
+
block[0, 0] = next_token.item()
|
| 702 |
+
|
| 703 |
+
# Draft phase (bidirectional, LoRA ON) — iterate at threshold>0 so
|
| 704 |
+
# that even low-confidence blocks make progress.
|
| 705 |
+
_set_diffusion_lm(True)
|
| 706 |
+
_toggle_adapters(True)
|
| 707 |
+
while True:
|
| 708 |
+
is_mask = block == token_mask_id
|
| 709 |
+
if not is_mask.any():
|
| 710 |
+
break
|
| 711 |
+
|
| 712 |
+
enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
|
| 713 |
+
nfe += 1
|
| 714 |
+
|
| 715 |
+
draft_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 716 |
+
# LLaDA: logit[i] directly predicts position i — no shift needed.
|
| 717 |
+
|
| 718 |
+
if temperature > 0:
|
| 719 |
+
draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
|
| 720 |
+
draft_tokens = torch.multinomial(
|
| 721 |
+
draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
|
| 722 |
+
).view(1, block_length)
|
| 723 |
+
else:
|
| 724 |
+
draft_tokens = draft_logits.argmax(dim=-1)
|
| 725 |
+
draft_probs = torch.softmax(draft_logits, dim=-1)
|
| 726 |
+
|
| 727 |
+
if threshold > 0:
|
| 728 |
+
draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
|
| 729 |
+
draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
|
| 730 |
+
unmask = draft_conf >= threshold
|
| 731 |
+
# Force progress even when every masked position is below threshold.
|
| 732 |
+
if not unmask.any():
|
| 733 |
+
best_idx = draft_conf.view(-1).argmax()
|
| 734 |
+
unmask = torch.zeros_like(is_mask, dtype=torch.bool)
|
| 735 |
+
unmask.view(-1)[best_idx] = True
|
| 736 |
+
block[unmask] = draft_tokens[unmask]
|
| 737 |
+
else:
|
| 738 |
+
block[is_mask] = draft_tokens[is_mask]
|
| 739 |
+
break
|
| 740 |
+
|
| 741 |
+
# Verify phase (causal, LoRA OFF).
|
| 742 |
+
_set_diffusion_lm(False)
|
| 743 |
+
_toggle_adapters(False)
|
| 744 |
+
enc_out = self.encoder(
|
| 745 |
+
input_ids=block,
|
| 746 |
+
past_key_values=past_key_values,
|
| 747 |
+
use_cache=True,
|
| 748 |
+
use_causal_mask=True,
|
| 749 |
+
)
|
| 750 |
+
past_key_values = enc_out.past_key_values
|
| 751 |
+
nfe += 1
|
| 752 |
+
|
| 753 |
+
verify_logits = self.diffusion_head(enc_out.last_hidden_state)
|
| 754 |
+
if temperature > 0:
|
| 755 |
+
ar_tokens = torch.multinomial(
|
| 756 |
+
torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]),
|
| 757 |
+
num_samples=1,
|
| 758 |
+
).view(1, block_length)
|
| 759 |
+
else:
|
| 760 |
+
ar_tokens = verify_logits.argmax(dim=-1)
|
| 761 |
+
|
| 762 |
+
# Accept consecutive matches; AR also gives one bonus token at the tail.
|
| 763 |
+
accepted = 0
|
| 764 |
+
for i in range(block_length - 1):
|
| 765 |
+
if ar_tokens[0, i].item() == block[0, i + 1].item():
|
| 766 |
+
accepted += 1
|
| 767 |
+
else:
|
| 768 |
+
break
|
| 769 |
+
accepted += 1
|
| 770 |
+
|
| 771 |
+
accepted_toks = ar_tokens[:, :accepted]
|
| 772 |
+
generated.append(accepted_toks)
|
| 773 |
+
total_gen += accepted
|
| 774 |
+
|
| 775 |
+
_crop_dynamic_cache(past_key_values, cache_len + accepted)
|
| 776 |
+
next_token = ar_tokens[:, accepted - 1 : accepted]
|
| 777 |
+
|
| 778 |
+
if eos_token_id is not None:
|
| 779 |
+
eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
| 780 |
+
if len(eos_pos) > 0:
|
| 781 |
+
first_eos = eos_pos[0].item()
|
| 782 |
+
generated[-1] = accepted_toks[:, : first_eos + 1]
|
| 783 |
+
total_gen = total_gen - accepted + first_eos + 1
|
| 784 |
+
break
|
| 785 |
+
|
| 786 |
+
# Thinking-budget enforcement: force end-think as next seed if budget exhausted.
|
| 787 |
+
if end_think_token_id is not None and max_thinking_tokens is not None:
|
| 788 |
+
if total_gen > max_thinking_tokens:
|
| 789 |
+
all_gen = torch.cat(generated, dim=1)
|
| 790 |
+
if not (all_gen == end_think_token_id).any():
|
| 791 |
+
next_token = torch.tensor([[end_think_token_id]], device=device)
|
| 792 |
+
|
| 793 |
+
if total_gen >= max_new_tokens:
|
| 794 |
+
break
|
| 795 |
+
|
| 796 |
+
all_generated = torch.cat(generated, dim=1)
|
| 797 |
+
output_ids = torch.cat([prompt_ids, all_generated], dim=1)
|
| 798 |
+
return output_ids, nfe
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
# ─── Module-level helpers used by `generate` and `linear_spec_generate` ──
|
| 802 |
+
|
| 803 |
+
def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
|
| 804 |
+
"""Crop a DynamicCache to max_length, compatible with both old and new transformers."""
|
| 805 |
+
if hasattr(past_key_values, 'crop'):
|
| 806 |
+
past_key_values.crop(max_length)
|
| 807 |
+
else:
|
| 808 |
+
for layer_idx in range(len(past_key_values)):
|
| 809 |
+
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
|
| 810 |
+
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
|
| 811 |
+
past_key_values._seen_tokens = max_length
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def _add_gumbel_noise(logits, temperature):
|
| 815 |
+
"""Gumbel-max sampling in float64 (low-precision Gumbel hurts MDM quality)."""
|
| 816 |
+
if temperature == 0:
|
| 817 |
+
return logits
|
| 818 |
+
logits = logits.to(torch.float64)
|
| 819 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 820 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 821 |
+
return logits.exp() / gumbel_noise
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def _get_num_transfer_tokens(mask_index, steps: int):
|
| 825 |
+
"""Even split of masked positions across `steps`, with remainder front-loaded."""
|
| 826 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 827 |
+
base = mask_num // steps
|
| 828 |
+
remainder = mask_num % steps
|
| 829 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 830 |
+
for i in range(mask_num.size(0)):
|
| 831 |
+
num_transfer_tokens[i, : int(remainder[i])] += 1
|
| 832 |
+
return num_transfer_tokens
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
def _get_transfer_index(logits, temperature, mask_index, x, num_transfer_tokens, threshold=None):
|
| 836 |
+
"""Pick which masked positions to commit this denoising step.
|
| 837 |
+
|
| 838 |
+
Returns (x0, transfer_index): x0 is argmax tokens (clamped to original x at
|
| 839 |
+
non-masked positions); transfer_index is a bool mask over positions to
|
| 840 |
+
finalize, chosen by top-k confidence (and filtered by `threshold` if given).
|
| 841 |
+
"""
|
| 842 |
+
logits_with_noise = _add_gumbel_noise(logits, temperature=temperature)
|
| 843 |
+
x0 = torch.argmax(logits_with_noise, dim=-1)
|
| 844 |
+
|
| 845 |
+
p = F.softmax(logits, dim=-1)
|
| 846 |
+
x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
|
| 847 |
+
|
| 848 |
+
x0 = torch.where(mask_index, x0, x)
|
| 849 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 850 |
+
|
| 851 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 852 |
+
if threshold is not None:
|
| 853 |
+
num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
|
| 854 |
+
for j in range(confidence.shape[0]):
|
| 855 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
|
| 856 |
+
transfer_index[j, select_index] = True
|
| 857 |
+
if threshold is not None:
|
| 858 |
+
for k in range(1, num_transfer_tokens[j]):
|
| 859 |
+
if confidence[j, select_index[k]] < threshold:
|
| 860 |
+
transfer_index[j, select_index[k]] = False
|
| 861 |
+
return x0, transfer_index
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
|
| 865 |
+
"""Return a Bool mask of length len(log_w) with exactly k True."""
|
| 866 |
+
g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
|
| 867 |
+
topk = torch.topk(log_w + g, k).indices
|
| 868 |
+
mask = torch.zeros_like(log_w, dtype=torch.bool)
|
| 869 |
+
mask[topk] = True
|
| 870 |
+
return mask
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"unk_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3277c00fe5fb3963b3cb7c07b7f183722d2af4d775a4aea7cfb3684d7cccbc2f
|
| 3 |
+
size 17078330
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|