Clean up rope params; ensure transformers 4.55/5.0 compatibility

#2
by abhgarg - opened
.gitattributes CHANGED
@@ -34,8 +34,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
- assets/demo.gif filter=lfs diff=lfs merge=lfs -text
38
- assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
39
- assets/result_acc.png filter=lfs diff=lfs merge=lfs -text
40
- assets/result_efficiency.png filter=lfs diff=lfs merge=lfs -text
41
- assets/teaser.png filter=lfs diff=lfs merge=lfs -text
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
README.md CHANGED
@@ -1,160 +1,160 @@
1
  ---
2
  library_name: transformers
3
- license: other
4
- license_name: nvidia-nemotron-open-model-license
5
- license_link: >-
6
- https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/
7
- pipeline_tag: text-generation
8
- tags:
9
- - nvidia
10
- - pytorch
11
  ---
12
 
13
- # Nemotron-Labs-Diffusion-3B
14
 
 
15
 
16
- <div align="center" style="line-height: 1;">
17
- <a href="https://d1qx31qr3h6wln.cloudfront.net/publications/Nemotron_Diffusion_Tech_Report_v1.pdf?VersionId=db8_EMO8B.vmU26.jr7Le9pN3MqcUDNL" target="_blank" style="margin: 2px;">
18
- <img alt="Chat" src="https://img.shields.io/badge/📝Paper-Read Now!-536af5?color=76B900&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
19
- </a>
20
- <a href="https://huggingface.co/collections/nvidia/nemotron-labs-diffusion" target="_blank" style="margin: 2px;">
21
- <img alt="Nemotron-Labs-Diffusion Model Family" src="https://img.shields.io/badge/%F0%9F%A4%97-Nemotron--Labs--Diffusion_Model_Family-76B900" style="display: inline-block; vertical-align: middle;"/>
22
- </a>
23
- <a href="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/" style="margin: 2px;">
24
- <img alt="License" src="https://img.shields.io/badge/License-NVIDIA Open Model License-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
25
- </a>
26
- </div>
27
 
 
28
 
29
- [![Demo](./assets/demo.gif)](./assets/demo.mp4)
30
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- ## Model Overview
 
 
 
 
33
 
34
- Nemotron-Labs-Diffusion is a tri-mode language model that supports both AR decoding and diffusion-based parallel decoding by simply switching the attention pattern of the same model during inference. The synergy between these two modes enables a third mode, called self-speculation: the same model performs diffusion-based parallel drafting and AR verification with shared KV cache, achieving high acceptance lengths and decoding efficiency. The seamless mode switching by simply changing attention patterns enables high efficiency at different concurrency levels in varying deployment scenarios with one single model.
35
 
36
- <div align="center">
37
- <img src="./assets/teaser.png" alt="An illustration of Tri-Mode LMs" width="500">
38
- </div>
39
 
 
 
 
40
 
41
- ## Highlights
 
 
 
42
 
43
- - SOTA 3B, 8B, 14B dense LM family (base, instruct, and vision-language variants) supporting AR, diffusion, and self-speculation with the focus on decode efficiency.
44
- - Generation moved from a memory-bound regime toward a compute-bound regime. Model weights are loaded once and reused to compute multiple tokens during generation.
45
- - Self-speculation uses diffusion for drafting and AR for verification, providing a stronger alternative to MTP approaches:
46
- * 3x higher acceptance length and 2.2x speed-up vs. Qwen3-8B-Eagle3 in SGLang.
47
- * 5.9× tokens per forward over Qwen3-8B (no MTP) with the same accuracy.
48
- - Real-device speed-up across platforms:
49
- * DGX Spark (8B, concurrency 1): 2.7x faster with 112 tok/sec vs. 41.8 tok/sec AR using w4a16.
50
- * GB200 (8B, concurrency 1): 3.3x faster with 850 tok/sec vs. 253 tok/sec AR and 360 tok/sec Eagle3. Custom CUDA kernels boost to 1015 tok/sec (4x).
51
- - Diffusion speedup-of-light analysis shows that throughput can be further doubled (vs. current best) for a single user with better sampling - future research.
52
 
 
53
 
54
- <div align="center">
55
- <img src="./assets/result_acc.png" alt="Efficiency Results" width="800">
56
- </div>
57
 
58
- <div align="center">
59
- <img src="./assets/result_efficiency.png" alt="Acc Results" width="800">
60
- </div>
61
 
 
62
 
63
- ## License/Terms of Use
 
 
64
 
65
- Use of this model is governed by the [NVIDIA Nemotron Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-nemotron-open-model-license/).
66
 
 
 
67
 
68
- ## Environment
 
 
69
 
70
- ```bash
71
- transformers>=5.0.0
 
72
  ```
73
 
74
- ## Chat with Our Model
75
 
 
76
 
77
  ```
78
- from transformers import AutoModel, AutoTokenizer
79
  import torch
80
 
81
- repo_name = "nvidia/Nemotron-Labs-Diffusion-3B"
82
 
83
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
84
- model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
85
- model = model.cuda().to(torch.bfloat16)
 
 
 
86
 
87
  history = []
88
 
89
  user_input = input("User: ").strip()
90
  history.append({"role": "user", "content": user_input})
91
 
92
- prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
93
- prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
94
 
95
- ## Chat in AR Mode
96
- out_ids, nfe = model.ar_generate(inputs.input_ids, max_new_tokens=512)
97
 
98
- ## Chat in dLM Mode
99
- out_ids, nfe = model.generate(prompt_ids, max_new_tokens=512, block_length=32, threshold=0.9, eos_token_id=tokenizer.eos_token_id)
100
 
101
- ## Chat in Linear Self-Speculation Mode
102
- out_ids, nfe = model.linear_spec_generate(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
103
 
104
- tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
105
  print(f"Model: {tokenized_out}")
106
  print(f"[Num Function Eval (NFE)={nfe}]")
107
  ```
108
 
 
109
 
 
 
 
110
 
111
- ## Inference with Linear Self-Speculation + LoRA-enhanced Drafter
112
 
113
- An optional LoRA adatper can be applied to the diffusion drafter in the linear self-speculation mode to further increase the acceptance length:
 
 
114
 
 
115
 
116
- ```python
117
- import torch
118
- from transformers import AutoModel, AutoTokenizer
119
- from peft import PeftModel
120
 
121
- repo = "nvidia/Nemotron-Labs-Diffusion-3B"
122
- tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
123
- model = AutoModel.from_pretrained(repo, trust_remote_code=True)
124
- model = model.cuda().to(torch.bfloat16)
125
 
126
- # Attach the linear_spec LoRA adapter.
127
- model = PeftModel.from_pretrained(model, repo, subfolder="linear_spec_lora").eval()
128
- # Unwrap so we can call linear_spec_generate directly (it toggles LoRA internally).
129
- base = model.model
 
 
 
130
 
131
- history = [{"role": "user", "content": "Solve: What is 15% of 240?"}]
132
- prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
133
- prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
134
-
135
- out_ids, nfe = base.linear_spec_generate(
136
- prompt_ids, max_new_tokens=512, block_length=32,
137
- eos_token_id=tokenizer.eos_token_id,
138
- )
139
- print(tokenizer.decode(out_ids[0, prompt_ids.shape[1]:], skip_special_tokens=True))
140
- print(f"[NFE={nfe}]")
141
  ```
 
 
142
 
 
143
 
144
- ## Ethical Considerations
145
- NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the [bias](./model_cards/bias.md), [explainability](./model_cards/explainability.md), [safety & security](./model_cards/safety.md), and [privacy](./model_cards/privacy.md) subcards.
 
146
 
147
- Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
148
 
 
 
149
 
150
- ## Citations
 
 
151
 
152
- ```bibtex
153
- @techreport{fu2026nemotronlabsdiffusion,
154
- title = {Nemotron-Labs-Diffusion: A Tri-Mode Language Model Unifying Autoregressive, Diffusion, and Self-Speculation Decoding},
155
- author = {Yonggan Fu and Lexington Whalen and Abhinav Garg and Chengyue Wu and Maksim Khadkevich and Nicolai Oswald and Enze Xie and Daniel Egert and Sharath Turuvekere Sreenivas and Shizhe Diao and Chenhan Yu and Ye Yu and Weijia Chen and Sajad Norouzi and Shiyi Lan and Ligeng Zhu and Jin Wang and Jindong Jiang and Morteza Mardani and Mehran Maghoumi and Song Han and Ante Jukic and Nima Tajbakhsh and Jan Kautz and Pavlo Molchanov},
156
- institution = {NVIDIA},
157
- year = {2026},
158
- note = {Technical report}
159
- }
160
  ```
 
 
1
  ---
2
  library_name: transformers
3
+ tags: []
 
 
 
 
 
 
 
4
  ---
5
 
6
+ # Nemotron-Diffusion-Exp-Ministral-3B-Instruct
7
 
8
+ Developed by [DLER team](https://nv-dler.github.io/) @ NVR and will be updated actively. Contact Yonggan Fu and Pavlo Molchanov for any question.
9
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Environment
12
 
13
+ Docker path: `/lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm_ministral.sqsh` on CW-DFW. Apply for interactive nodes with the following command:
14
 
15
+ ```
16
+ srun -A {account} --partition interactive --time 4:00:00 --gpus 8 --container-image /lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm_ministral.sqsh --container-mounts=$HOME:/home,/lustre:/lustre --pty bash
17
+ ```
18
+
19
+ ## Chat with Our Model in dLM Mode
20
+
21
+
22
+ ```
23
+ from transformers import AutoModel, AutoTokenizer
24
+ import torch
25
 
26
+ repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
29
+ model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
30
+ model = model.cuda().to(torch.bfloat16)
31
 
32
+ history = []
33
 
34
+ user_input = input("User: ").strip()
35
+ history.append({"role": "user", "content": user_input})
 
36
 
37
+ prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
38
+ prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
39
+ out_ids, nfe = model.generate(prompt_ids, max_new_tokens=512, steps=512, block_length=32, shift_logits=False, causal_context=True, threshold=0.9, eos_token_id=tokenizer.eos_token_id)
40
 
41
+ tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
42
+ print(f"Model: {tokenized_out}")
43
+ print(f"[Num Function Eval (NFE)={nfe}]")
44
+ ```
45
 
 
 
 
 
 
 
 
 
 
46
 
47
+ ## Chat with Our Model in AR Mode
48
 
 
 
 
49
 
50
+ ```
51
+ from transformers import AutoModel, AutoTokenizer
52
+ import torch
53
 
54
+ repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
55
 
56
+ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
57
+ model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
58
+ model = model.cuda().to(torch.bfloat16)
59
 
60
+ history = []
61
 
62
+ user_input = input("User: ").strip()
63
+ history.append({"role": "user", "content": user_input})
64
 
65
+ prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
66
+ prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
67
+ out_ids, nfe = model.ar_generate(inputs.input_ids, max_new_tokens=512)
68
 
69
+ tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
70
+ print(f"Model: {tokenized_out}")
71
+ print(f"[Num Function Eval (NFE)={nfe}]")
72
  ```
73
 
 
74
 
75
+ ## Chat with Our Model in Quadratic Self-Speculation Mode
76
 
77
  ```
78
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
79
  import torch
80
 
81
+ repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
82
 
83
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
84
+
85
+ config = AutoConfig.from_pretrained(repo_name, trust_remote_code=True)
86
+ config.enable_self_spec = True
87
+
88
+ model = AutoModel.from_pretrained(repo_name, config=config, trust_remote_code=True).cuda().to(torch.bfloat16)
89
 
90
  history = []
91
 
92
  user_input = input("User: ").strip()
93
  history.append({"role": "user", "content": user_input})
94
 
95
+ prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
 
96
 
97
+ inputs = tokenizer(prompt, return_tensors="pt")
98
+ inputs = inputs.to("cuda")
99
 
100
+ out_ids, nfe = model.self_spec_generate(inputs.input_ids, max_new_tokens=512, steps=512, block_length=32, ar_mix_weight=0.5, eos_token_id=tokenizer.eos_token_id)
 
101
 
102
+ tokenized_out = tokenizer.batch_decode(out_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
 
103
 
 
104
  print(f"Model: {tokenized_out}")
105
  print(f"[Num Function Eval (NFE)={nfe}]")
106
  ```
107
 
108
+ ## Chat with Our Model in Linear Self-Speculation Mode
109
 
110
+ ```
111
+ from transformers import AutoModel, AutoTokenizer
112
+ import torch
113
 
114
+ repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
115
 
116
+ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
117
+ model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
118
+ model = model.cuda().to(torch.bfloat16)
119
 
120
+ history = []
121
 
122
+ user_input = input("User: ").strip()
123
+ history.append({"role": "user", "content": user_input})
 
 
124
 
125
+ prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
126
+ prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
127
+ out_ids, nfe = model.linear_spec_generate(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
 
128
 
129
+ tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
130
+ print(f"Model: {tokenized_out}")
131
+ print(f"[Num Function Eval (NFE)={nfe}]")
132
+ ```
133
+
134
+
135
+ ## Chat with Our Model in Linear Decoding Mode with Multi-Path Verification
136
 
 
 
 
 
 
 
 
 
 
 
137
  ```
138
+ from transformers import AutoModel, AutoTokenizer
139
+ import torch
140
 
141
+ repo_name = "nvidia/Nemotron-Diffusion-Exp-Ministral-3B-Instruct"
142
 
143
+ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
144
+ model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
145
+ model = model.cuda().to(torch.bfloat16)
146
 
147
+ history = []
148
 
149
+ user_input = input("User: ").strip()
150
+ history.append({"role": "user", "content": user_input})
151
 
152
+ prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, enable_thinking=False)
153
+ prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device='cuda')
154
+ out_ids, nfe = model.linear_spec_generate_mp(prompt_ids, max_new_tokens=512, block_length=32, eos_token_id=tokenizer.eos_token_id)
155
 
156
+ tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
157
+ print(f"Model: {tokenized_out}")
158
+ print(f"[Num Function Eval (NFE)={nfe}]")
 
 
 
 
 
159
  ```
160
+
assets/demo.gif DELETED

Git LFS Details

  • SHA256: 0d09264e272ac0f82dee36417f6a16511287ec1f8dee3b5dba3da222d791fd2c
  • Pointer size: 132 Bytes
  • Size of remote file: 8.25 MB
assets/demo.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:666d8785ac4af75931d9c677757c4ef9945bf114d07f1c4e2ebb7b893ac39006
3
- size 9454873
 
 
 
 
assets/result_acc.png DELETED

Git LFS Details

  • SHA256: 992aa22ca9eca3d0bddbcd9f49837e2a9f377bbc0f7545563b129a50b3811448
  • Pointer size: 131 Bytes
  • Size of remote file: 405 kB
assets/result_efficiency.png DELETED

Git LFS Details

  • SHA256: 4f6161912e2aa703e0ef1bdccbb85039529b97e759d6247c33afa2a209806ede
  • Pointer size: 131 Bytes
  • Size of remote file: 801 kB
assets/teaser.png DELETED

Git LFS Details

  • SHA256: 6c94aa7b0c6cf8fb739724d0c1ce45749c76443c592eeab94d7cbb9083c6c6b1
  • Pointer size: 131 Bytes
  • Size of remote file: 581 kB
chat_utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def add_gumbel_noise(logits, temperature):
7
+ '''
8
+ The Gumbel max is a method for sampling categorical distributions.
9
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
10
+ Thus, we use float64.
11
+ '''
12
+ if temperature == 0:
13
+ return logits
14
+ logits = logits.to(torch.float64)
15
+ noise = torch.rand_like(logits, dtype=torch.float64)
16
+ gumbel_noise = (- torch.log(noise)) ** temperature
17
+ return logits.exp() / gumbel_noise
18
+
19
+
20
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
21
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
22
+ x0 = torch.argmax(logits_with_noise, dim=-1)
23
+
24
+ if remasking == 'low_confidence':
25
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
26
+ p = F.softmax(logits, dim=-1)
27
+ x0_p = torch.squeeze(
28
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
29
+ elif remasking == 'top_p_margin':
30
+ # Compute probabilities
31
+ p = F.softmax(logits, dim=-1) # (B, L, V)
32
+ # Top-2 per position
33
+ top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
34
+ margin = top2[..., 0] - top2[..., 1] # (B, L)
35
+
36
+ # Normalize margin to [0,1] over MASKED positions per row
37
+ plus_inf = torch.full_like(margin, float('inf'))
38
+ minus_inf = torch.full_like(margin, float('-inf'))
39
+ masked_for_min = torch.where(mask_index, margin, plus_inf)
40
+ masked_for_max = torch.where(mask_index, margin, minus_inf)
41
+ row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
42
+ row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
43
+ denom = (row_max - row_min)
44
+
45
+ # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
46
+ normalized = torch.zeros_like(margin)
47
+ nonzero = denom > 0
48
+ normalized = torch.where(
49
+ mask_index & nonzero,
50
+ (margin - row_min) / (denom + 1e-12),
51
+ normalized
52
+ )
53
+ normalized = torch.where(
54
+ mask_index & (~nonzero),
55
+ torch.ones_like(normalized),
56
+ normalized
57
+ )
58
+ x0_p = normalized # ∈ [0,1] on masked positions
59
+ elif remasking == 'random':
60
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
61
+ else:
62
+ raise NotImplementedError(remasking)
63
+
64
+ # Calculate negative entropy if requested
65
+ if neg_entropy:
66
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
67
+ p = F.softmax(logits, dim=-1)
68
+ epsilon = 1e-10
69
+ log_probs = torch.log(p + epsilon)
70
+ confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
71
+ else:
72
+ confidence_scores = x0_p
73
+
74
+ x0 = torch.where(mask_index, x0, x)
75
+ confidence = torch.where(mask_index, confidence_scores, -np.inf)
76
+
77
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
78
+ if threshold is not None:
79
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
80
+ # print(f'confidence: {confidence}')
81
+ for j in range(confidence.shape[0]):
82
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
83
+ transfer_index[j, select_index] = True
84
+ if threshold is not None:
85
+ for k in range(1, num_transfer_tokens[j]):
86
+ if confidence[j, select_index[k]] < threshold:
87
+ transfer_index[j, select_index[k]] = False
88
+ return x0, transfer_index
89
+
90
+
91
+ def get_num_transfer_tokens(mask_index, steps: int):
92
+ mask_num = mask_index.sum(dim=1, keepdim=True)
93
+ base = mask_num // steps
94
+ remainder = mask_num % steps
95
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
96
+ for i in range(mask_num.size(0)):
97
+ num_transfer_tokens[i, : int(remainder[i])] += 1
98
+ return num_transfer_tokens
99
+
100
+
101
+ @torch.no_grad()
102
+ def generate_with_prefix_cache_block_diff(
103
+ model,
104
+ prompt,
105
+ steps=128,
106
+ gen_length=128,
107
+ block_length=128,
108
+ temperature=0.,
109
+ remasking='low_confidence',
110
+ mask_id=126336,
111
+ threshold=None,
112
+ factor=None,
113
+ shift_logits=False,
114
+ neg_entropy=False,
115
+ causal_context=False,
116
+ eos_token_id=None,
117
+ max_thinking_tokens=None,
118
+ end_think_token_id=None,
119
+ ):
120
+ dream_style=shift_logits
121
+ x_accum = prompt.clone()
122
+ B = prompt.shape[0]
123
+
124
+ assert gen_length % block_length == 0
125
+ num_blocks = gen_length // block_length
126
+
127
+ assert steps % num_blocks == 0
128
+ steps_per_block = steps // num_blocks
129
+
130
+ nfe = 0
131
+
132
+ if causal_context:
133
+ model_module = model.module if hasattr(model, "module") else model
134
+ for layer in model_module.encoder.layers:
135
+ if hasattr(layer.self_attn, 'diffusion_lm'):
136
+ layer.self_attn.diffusion_lm=False
137
+
138
+ # Compute KV cache for the prompt initially
139
+ output = model(prompt, use_cache=True, use_causal_mask=causal_context)
140
+ past_key_values = output.past_key_values
141
+
142
+ if causal_context:
143
+ for layer in model_module.encoder.layers:
144
+ if hasattr(layer.self_attn, 'diffusion_lm'):
145
+ layer.self_attn.diffusion_lm=True
146
+
147
+ # Causal prefill: next token from last position (same as linear_spec_generate).
148
+ next_token = None
149
+ if causal_context:
150
+ last_logit = output.logits[:, -1, :]
151
+ if temperature > 0:
152
+ probs = torch.softmax(last_logit / temperature, dim=-1)
153
+ next_token = torch.multinomial(probs, num_samples=1)
154
+ else:
155
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
156
+
157
+ # For dream_style: store the "next token logit" of the context
158
+ next_logits_context = None
159
+ if dream_style:
160
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
161
+
162
+ for num_block in range(num_blocks):
163
+ # Create a new block with mask tokens; under causal context, seed position 0
164
+ # with the next-token prediction from the previous causal forward (prefill or
165
+ # post-block encode), matching linear_spec_generate.
166
+ mask_block = torch.ones(
167
+ (prompt.shape[0], block_length),
168
+ dtype=prompt.dtype,
169
+ device=prompt.device
170
+ ) * mask_id
171
+ if causal_context:
172
+ mask_block[:, 0] = next_token[:, 0]
173
+
174
+ # Append the block of masks
175
+ x_accum = torch.cat([x_accum, mask_block], dim=1)
176
+ current_block_start = prompt.size(1) + num_block * block_length
177
+ block_slice = slice(current_block_start, current_block_start + block_length)
178
+
179
+ # ---- thinking budget enforcement ----
180
+ # If we've generated >= max_thinking_tokens without a </think>, inject one.
181
+ if end_think_token_id is not None and max_thinking_tokens is not None:
182
+ tokens_before_block = num_block * block_length
183
+ tokens_after_block = tokens_before_block + block_length
184
+ if tokens_after_block > max_thinking_tokens:
185
+ gen_so_far = x_accum[:, prompt.size(1):current_block_start]
186
+ has_end_think = (
187
+ (gen_so_far == end_think_token_id).any(dim=1)
188
+ if gen_so_far.size(1) > 0
189
+ else torch.zeros(B, dtype=torch.bool, device=prompt.device)
190
+ )
191
+ if not has_end_think.all():
192
+ if tokens_before_block < max_thinking_tokens:
193
+ offset = max_thinking_tokens - tokens_before_block
194
+ else:
195
+ offset = 0
196
+ inject_pos = current_block_start + offset
197
+ for b in range(B):
198
+ if not has_end_think[b]:
199
+ x_accum[b, inject_pos] = end_think_token_id
200
+
201
+ # Build the initial mask for this block
202
+ mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
203
+
204
+ # Precompute the transfer schedule for this block
205
+ if dream_style:
206
+ # masked positions only (position 0 may be causal-seeded, not mask_id)
207
+ schedule_mask = mask_block_idx0
208
+ else:
209
+ schedule_mask = mask_block_idx0
210
+
211
+ num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
212
+
213
+ # Denoise the current block
214
+ for i in range(steps_per_block):
215
+ mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
216
+ if mask_block_idx.sum() == 0:
217
+ break
218
+
219
+ nfe += 1
220
+
221
+ # Forward only the current noisy block using cached context
222
+ logits_block = model(
223
+ x_accum[:, block_slice],
224
+ past_key_values=past_key_values,
225
+ use_cache=False
226
+ ).logits
227
+
228
+ if dream_style:
229
+ # Align logits so that each masked position has a predictor:
230
+ # prepend context-next logit, then use logits_block[:-1]
231
+ if block_length == 1:
232
+ logits_use = next_logits_context # (B, 1, V)
233
+ else:
234
+ logits_use = torch.cat(
235
+ [next_logits_context, logits_block[:, :-1, :]],
236
+ dim=1
237
+ ) # (B, Lb, V)
238
+
239
+ mask_use = mask_block_idx # (B, Lb)
240
+ x_use = x_accum[:, block_slice] # (B, Lb)
241
+
242
+ x0, transfer_idx = get_transfer_index(
243
+ logits_use, temperature, remasking, mask_use, x_use,
244
+ num_transfer_tokens=num_transfer_tokens[:, i],
245
+ threshold=threshold, neg_entropy=neg_entropy
246
+ )
247
+ cur = x_accum[:, block_slice].clone()
248
+ cur[transfer_idx] = x0[transfer_idx]
249
+ x_accum[:, block_slice] = cur
250
+
251
+ else:
252
+ # non-AR (same-position) case
253
+ x0, transfer_idx = get_transfer_index(
254
+ logits_block, temperature, remasking, mask_block_idx,
255
+ x_accum[:, block_slice],
256
+ num_transfer_tokens=num_transfer_tokens[:, i],
257
+ threshold=threshold, neg_entropy=neg_entropy
258
+ )
259
+ cur = x_accum[:, block_slice].clone()
260
+ cur[transfer_idx] = x0[transfer_idx]
261
+ x_accum[:, block_slice] = cur
262
+
263
+ if eos_token_id is not None:
264
+ block_tokens = x_accum[:, block_slice] # (B, Lb)
265
+ eos_mask = (block_tokens == eos_token_id) # (B, Lb)
266
+ any_eos = eos_mask.any(dim=1) # (B,)
267
+ if any_eos.any():
268
+ after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
269
+ mask_before = (block_tokens == mask_id) & ~after_eos
270
+ if (any_eos & ~mask_before.any(dim=1)).any():
271
+ break
272
+
273
+ if causal_context:
274
+ for layer in model_module.encoder.layers:
275
+ if hasattr(layer.self_attn, 'diffusion_lm'):
276
+ layer.self_attn.diffusion_lm=False
277
+
278
+ # after block is fully denoised, update KV cache
279
+ output = model(
280
+ x_accum[:, block_slice],
281
+ past_key_values=past_key_values,
282
+ use_cache=True,
283
+ use_causal_mask=causal_context
284
+ )
285
+ past_key_values = output.past_key_values
286
+ nfe += 1
287
+
288
+ if causal_context:
289
+ for layer in model_module.encoder.layers:
290
+ if hasattr(layer.self_attn, 'diffusion_lm'):
291
+ layer.self_attn.diffusion_lm=True
292
+ # Next block's first position = greedy/sampled next token from this causal encode
293
+ last_logit = output.logits[:, -1, :]
294
+ if temperature > 0:
295
+ probs = torch.softmax(last_logit / temperature, dim=-1)
296
+ next_token = torch.multinomial(probs, num_samples=1)
297
+ else:
298
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
299
+
300
+ if dream_style and num_block < num_blocks - 1:
301
+ # refresh context-next logit for the next block
302
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
303
+
304
+ if eos_token_id is not None:
305
+ gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
306
+ is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
307
+ has_eos = is_eos.any(dim=1) # (B,)
308
+ if has_eos.all():
309
+ first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
310
+ max_eos = first_eos_pos.max().item()
311
+ return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
312
+
313
+ return x_accum, nfe
config.json CHANGED
@@ -1,21 +1,31 @@
1
  {
 
 
 
 
2
  "ar_loss_weight": 1.0,
3
  "architectures": [
4
- "NemotronLabsDiffusionModel"
5
  ],
6
  "attention_bias": false,
7
  "attention_dropout": 0.0,
8
  "attn_implementation": "sdpa",
9
  "auto_map": {
10
- "AutoConfig": "configuration_nemotron_labs_diffusion.NemotronLabsDiffusionConfig",
11
- "AutoModel": "modeling_nemotron_labs_diffusion.NemotronLabsDiffusionModel"
12
  },
13
  "block_size": 32,
14
  "bos_token_id": 1,
 
 
15
  "dlm_loss_weight": null,
16
  "dlm_paradigm": "bidirectional",
 
17
  "dp_varying_mask_ratio": false,
 
 
18
  "eos_token_id": 11,
 
19
  "head_dim": 128,
20
  "hidden_act": "silu",
21
  "hidden_size": 3072,
@@ -24,10 +34,16 @@
24
  "mask_token_id": 100,
25
  "max_position_embeddings": 262144,
26
  "mlp_bias": false,
27
- "model_type": "nemotron_labs_diffusion",
 
 
28
  "num_attention_heads": 32,
 
29
  "num_hidden_layers": 26,
30
  "num_key_value_heads": 8,
 
 
 
31
  "rms_norm_eps": 1e-05,
32
  "rope_parameters": {
33
  "beta_fast": 32.0,
@@ -42,6 +58,7 @@
42
  },
43
  "sliding_window": null,
44
  "tie_word_embeddings": false,
 
45
  "torch_dtype": "bfloat16",
46
  "transformers_version": "5.0.0",
47
  "use_cache": false,
 
1
  {
2
+ "ada_dlm_loss_ratio": null,
3
+ "ada_perm_ratio_global": null,
4
+ "ada_perm_ratio_per_block": null,
5
+ "adaptive_mask_rate": false,
6
  "ar_loss_weight": 1.0,
7
  "architectures": [
8
+ "MinistralDiffEncoderModel"
9
  ],
10
  "attention_bias": false,
11
  "attention_dropout": 0.0,
12
  "attn_implementation": "sdpa",
13
  "auto_map": {
14
+ "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
15
+ "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
16
  },
17
  "block_size": 32,
18
  "bos_token_id": 1,
19
+ "diff_loss_weight": 1,
20
+ "dlm_arch": "encoder",
21
  "dlm_loss_weight": null,
22
  "dlm_paradigm": "bidirectional",
23
+ "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
25
+ "enable_self_spec": false,
26
+ "enforce_mask": false,
27
  "eos_token_id": 11,
28
+ "global_loss_avg": false,
29
  "head_dim": 128,
30
  "hidden_act": "silu",
31
  "hidden_size": 3072,
 
34
  "mask_token_id": 100,
35
  "max_position_embeddings": 262144,
36
  "mlp_bias": false,
37
+ "model_type": "ministral_dlm",
38
+ "multi_sampling": null,
39
+ "num_ar_layers": 0,
40
  "num_attention_heads": 32,
41
+ "num_diffusion_layers": 0,
42
  "num_hidden_layers": 26,
43
  "num_key_value_heads": 8,
44
+ "num_skip_loss_tokens": 0,
45
+ "prefix_ratio": 0.8,
46
+ "random_length_prob": 0,
47
  "rms_norm_eps": 1e-05,
48
  "rope_parameters": {
49
  "beta_fast": 32.0,
 
58
  },
59
  "sliding_window": null,
60
  "tie_word_embeddings": false,
61
+ "tok_mask_half_life_ratio": null,
62
  "torch_dtype": "bfloat16",
63
  "transformers_version": "5.0.0",
64
  "use_cache": false,
configuration_nemotron_labs_diffusion.py → configuration_ministral_dlm.py RENAMED
@@ -12,7 +12,7 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """Nemotron-Labs Diffusion model configuration"""
16
 
17
  from transformers.configuration_utils import PretrainedConfig
18
  from transformers.modeling_rope_utils import rope_config_validation
@@ -22,10 +22,10 @@ from transformers.utils import logging
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
- class NemotronLabsDiffusionConfig(PretrainedConfig):
26
  r"""
27
- This is the configuration class to store the configuration of a [`NemotronLabsDiffusionModel`] for diffusion language models.
28
- It is used to instantiate a NemotronLabsDiffusionModel according to the specified arguments, defining the model architecture.
29
 
30
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
  documentation from [`PretrainedConfig`] for more information.
@@ -72,19 +72,52 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
72
  Sliding window attention size.
73
  mask_token_id (`int`, *optional*, defaults to -1):
74
  Token ID for masking in diffusion.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
76
- Paradigm for diffusion ('bidirectional', 'autoregressive', 'block_diff').
 
 
77
  block_size (`int`, *optional*, defaults to 32):
78
  Block size for block diffusion paradigms.
 
 
 
 
 
 
 
 
79
  dlm_loss_weight (`float`, *optional*):
80
  Weight for diffusion LM loss.
81
  ar_loss_weight (`float`, *optional*, defaults to 1.0):
82
- Weight for autoregressive loss in block_diff paradigm. Use 10000 to only use AR loss.
 
 
83
  dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
84
  Whether to use varying mask ratio for each DP rank during sampling.
 
 
 
 
 
 
 
85
  """
86
 
87
- model_type = "nemotron_labs_diffusion"
88
  keys_to_ignore_at_inference = ["past_key_values"]
89
 
90
  # Default tensor parallel plan for base model `Ministral`
@@ -129,11 +162,28 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
129
  sliding_window=None,
130
  attn_implementation="sdpa",
131
  mask_token_id=-1,
 
 
 
 
 
 
 
132
  dlm_paradigm='bidirectional',
 
133
  block_size=32,
 
 
 
 
134
  dlm_loss_weight=None,
135
  ar_loss_weight=1.0,
 
136
  dp_varying_mask_ratio=False,
 
 
 
 
137
  **kwargs,
138
  ):
139
  self.vocab_size = vocab_size
@@ -168,11 +218,28 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
168
  self.attn_implementation = attn_implementation
169
 
170
  self.mask_token_id = mask_token_id
 
 
 
 
 
 
 
171
  self.dlm_paradigm = dlm_paradigm
 
172
  self.block_size = block_size
 
 
 
 
173
  self.dlm_loss_weight = dlm_loss_weight
174
  self.ar_loss_weight = ar_loss_weight
 
175
  self.dp_varying_mask_ratio = dp_varying_mask_ratio
 
 
 
 
176
  super().__init__(
177
  pad_token_id=pad_token_id,
178
  bos_token_id=bos_token_id,
@@ -182,5 +249,5 @@ class NemotronLabsDiffusionConfig(PretrainedConfig):
182
  )
183
 
184
 
185
- __all__ = ["NemotronLabsDiffusionConfig"]
186
 
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """Ministral DLM model configuration"""
16
 
17
  from transformers.configuration_utils import PretrainedConfig
18
  from transformers.modeling_rope_utils import rope_config_validation
 
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
+ class MinistralDLMConfig(PretrainedConfig):
26
  r"""
27
+ This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
+ It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
29
 
30
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
  documentation from [`PretrainedConfig`] for more information.
 
72
  Sliding window attention size.
73
  mask_token_id (`int`, *optional*, defaults to -1):
74
  Token ID for masking in diffusion.
75
+ dlm_type (`str`, *optional*, defaults to 'llada'):
76
+ Type of diffusion language model ('llada', 'dream').
77
+ random_length_prob (`float`, *optional*):
78
+ Probability of using random lengths during training.
79
+ num_ar_layers (`int`, *optional*, defaults to 0):
80
+ Number of autoregressive layers.
81
+ num_diffusion_layers (`int`, *optional*, defaults to 0):
82
+ Number of diffusion layers.
83
+ diff_loss_weight (`float`, *optional*, defaults to 1):
84
+ Weight for diffusion loss.
85
+ enforce_mask (`bool`, *optional*, defaults to False):
86
+ Whether to enforce masking.
87
+ prefix_ratio (`float`, *optional*, defaults to 0.8):
88
+ Ratio for prefix in prefix_bidirectional mode.
89
  dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
90
+ Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
91
+ dlm_arch (`str`, *optional*, defaults to 'encoder'):
92
+ Architecture type ('encoder', 'encoder_decoder').
93
  block_size (`int`, *optional*, defaults to 32):
94
  Block size for block diffusion paradigms.
95
+ tok_mask_half_life_ratio (`float`, *optional*):
96
+ Half-life ratio for token masking.
97
+ adaptive_mask_rate (`bool`, *optional*, defaults to False):
98
+ Whether to use adaptive mask rate.
99
+ multi_sampling (`int`, *optional*):
100
+ Number of samples for multi-sampling.
101
+ num_skip_loss_tokens (`int`, *optional*, defaults to 0):
102
+ Number of tokens to skip in loss calculation.
103
  dlm_loss_weight (`float`, *optional*):
104
  Weight for diffusion LM loss.
105
  ar_loss_weight (`float`, *optional*, defaults to 1.0):
106
+ Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
107
+ global_loss_avg (`bool`, *optional*, defaults to False):
108
+ Whether to use global loss average.
109
  dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
110
  Whether to use varying mask ratio for each DP rank during sampling.
111
+ ada_perm_ratio_per_block (`float`, *optional*):
112
+ Adaptive permutation ratio for each block.
113
+ ada_perm_ratio_global (`float`, *optional*):
114
+ Adaptive permutation ratio for global.
115
+ enable_self_spec (`bool`, *optional*, defaults to `False`):
116
+ Force MinistralFlexAttention for all paradigms (including bidirectional/autoregressive).
117
+ Required for self speculative generation; leave False for standard eval to use faster SDPA kernels.
118
  """
119
 
120
+ model_type = "ministral_dlm"
121
  keys_to_ignore_at_inference = ["past_key_values"]
122
 
123
  # Default tensor parallel plan for base model `Ministral`
 
162
  sliding_window=None,
163
  attn_implementation="sdpa",
164
  mask_token_id=-1,
165
+ dlm_type='llada',
166
+ random_length_prob=None,
167
+ num_ar_layers=0,
168
+ num_diffusion_layers=0,
169
+ diff_loss_weight=1,
170
+ enforce_mask=False,
171
+ prefix_ratio=0.8,
172
  dlm_paradigm='bidirectional',
173
+ dlm_arch='encoder',
174
  block_size=32,
175
+ tok_mask_half_life_ratio=None,
176
+ adaptive_mask_rate=False,
177
+ multi_sampling=None,
178
+ num_skip_loss_tokens=0,
179
  dlm_loss_weight=None,
180
  ar_loss_weight=1.0,
181
+ global_loss_avg=False,
182
  dp_varying_mask_ratio=False,
183
+ ada_perm_ratio_per_block=None,
184
+ ada_perm_ratio_global=None,
185
+ ada_dlm_loss_ratio=None,
186
+ enable_self_spec=False,
187
  **kwargs,
188
  ):
189
  self.vocab_size = vocab_size
 
218
  self.attn_implementation = attn_implementation
219
 
220
  self.mask_token_id = mask_token_id
221
+ self.dlm_type = dlm_type
222
+ self.random_length_prob = random_length_prob
223
+ self.num_ar_layers = num_ar_layers
224
+ self.num_diffusion_layers = num_diffusion_layers
225
+ self.diff_loss_weight = diff_loss_weight
226
+ self.enforce_mask = enforce_mask
227
+ self.prefix_ratio = prefix_ratio
228
  self.dlm_paradigm = dlm_paradigm
229
+ self.dlm_arch = dlm_arch
230
  self.block_size = block_size
231
+ self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
232
+ self.adaptive_mask_rate = adaptive_mask_rate
233
+ self.multi_sampling = multi_sampling
234
+ self.num_skip_loss_tokens = num_skip_loss_tokens
235
  self.dlm_loss_weight = dlm_loss_weight
236
  self.ar_loss_weight = ar_loss_weight
237
+ self.global_loss_avg = global_loss_avg
238
  self.dp_varying_mask_ratio = dp_varying_mask_ratio
239
+ self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
240
+ self.ada_perm_ratio_global = ada_perm_ratio_global
241
+ self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
242
+ self.enable_self_spec = enable_self_spec
243
  super().__init__(
244
  pad_token_id=pad_token_id,
245
  bos_token_id=bos_token_id,
 
249
  )
250
 
251
 
252
+ __all__ = ["MinistralDLMConfig"]
253
 
generation_config.json CHANGED
@@ -2,6 +2,6 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 11,
5
- "transformers_version": "5.0.0",
6
  "use_cache": false
7
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 11,
5
+ "transformers_version": "4.55.4",
6
  "use_cache": false
7
  }
linear_spec_lora/adapter_config.json DELETED
@@ -1,34 +0,0 @@
1
- {
2
- "alpha_pattern": {},
3
- "auto_mapping": {
4
- "base_model_class": "NemotronLabsDiffusionModel",
5
- "parent_library": "transformers_modules.Nemotron-Labs-Diffusion-3B.modeling_nemotron_labs_diffusion"
6
- },
7
- "base_model_name_or_path": "nvidia/Nemotron-Labs-Diffusion-3B",
8
- "bias": "none",
9
- "eva_config": null,
10
- "exclude_modules": null,
11
- "fan_in_fan_out": false,
12
- "inference_mode": true,
13
- "init_lora_weights": true,
14
- "layer_replication": null,
15
- "layers_pattern": null,
16
- "layers_to_transform": null,
17
- "loftq_config": {},
18
- "lora_alpha": 512,
19
- "lora_bias": false,
20
- "lora_dropout": 0.0,
21
- "megatron_config": null,
22
- "megatron_core": "megatron.core",
23
- "modules_to_save": null,
24
- "peft_type": "LORA",
25
- "r": 128,
26
- "rank_pattern": {},
27
- "revision": null,
28
- "target_modules": [
29
- "o_proj"
30
- ],
31
- "task_type": null,
32
- "use_dora": false,
33
- "use_rslora": false
34
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
linear_spec_lora/adapter_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:897ef67dff8a69bd1a908fa390ef2164fdaa738e0e47bec502e2f0d86311ff74
3
- size 95427600
 
 
 
 
model_cards/bias.md DELETED
@@ -1,4 +0,0 @@
1
- Field | Response
2
- :---------------------------------------------------------------------------------------------------|:---------------
3
- Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | [None]
4
- Measures taken to mitigate against unwanted bias: | [None]
 
 
 
 
 
model_cards/explainability.md DELETED
@@ -1,13 +0,0 @@
1
- Field | Response
2
- :------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------
3
- Intended Task/Domain: | Text generation
4
- Model Type: | Transformer
5
- Intended Users: | Generative AI creators working with conversational AI models.
6
- Output: | Text (Responds to posed question, Stateful - remembers previous answers)
7
- Describe how the model works: | Text input is encoded into tokens and passed into a transformer-based language model, which returns a text response.
8
- Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable
9
- Technical Limitations & Mitigation: | The model cannot perform long-horizon reasoning and tool calling.
10
- Verified to have met prescribed NVIDIA quality standards: | Yes
11
- Performance Metrics: | Accuracy, Latency, Throughput
12
- Potential Known Risks: | In some instances, the model may think too long and struggle to derive final answers. The model's output can generate all forms of text, including what may be considered toxic, offensive, or indecent.
13
- Licensing: | nvidia-open-model-license.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_cards/privacy.md DELETED
@@ -1,11 +0,0 @@
1
- Field | Response
2
- :----------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------
3
- Generatable or reverse engineerable personal data? | [No]
4
- Personal data used to create this model? | [No]
5
- Was consent obtained for any personal data used? | [Not Applicable]
6
- How often is dataset reviewed? | [During dataset creation, model training, evaluation, and the prerelease phase.]
7
- Was data from user interactions with the AI model (e.g. user input and prompts) used to train the model? | [Yes]
8
- Is there provenance for all datasets used in training? | Yes
9
- Does data labeling (annotation, metadata) comply with privacy laws? | Yes
10
- Is data compliant with data subject requests for data correction or removal, if such a request was made? | Not Applicable.
11
- Applicable Privacy Policy | https://www.nvidia.com/en-us/about-nvidia/privacy-policy/
 
 
 
 
 
 
 
 
 
 
 
 
model_cards/safety.md DELETED
@@ -1,6 +0,0 @@
1
- Field | Response
2
- :---------------------------------------------------|:----------------------------------
3
- Model Application Field(s): | [Media & Entertainment].
4
- Describe the life critical impact (if present). | Not Applicable
5
- Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to.
6
- Use Case Restrictions: | Abide by nvidia-open-model-license.
 
 
 
 
 
 
 
modeling_ministral.py CHANGED
@@ -25,7 +25,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
  from transformers.processing_utils import Unpack
26
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
  # from transformers.utils.generic import maybe_autocast
28
- from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
29
 
30
  #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
@@ -110,7 +110,7 @@ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_positi
110
  class Ministral3Attention(nn.Module):
111
  """Multi-headed attention from 'Attention Is All You Need' paper"""
112
 
113
- def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
114
  super().__init__()
115
  self.config = config
116
  self.layer_idx = layer_idx
@@ -234,7 +234,7 @@ class Ministral3RMSNorm(nn.Module):
234
 
235
 
236
  class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
- def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
238
  super().__init__()
239
  self.hidden_size = config.hidden_size
240
 
@@ -284,7 +284,7 @@ class Ministral3DecoderLayer(GradientCheckpointingLayer):
284
 
285
  @auto_docstring
286
  class Ministral3PreTrainedModel(PreTrainedModel):
287
- config: NemotronLabsDiffusionConfig
288
  base_model_prefix = "model"
289
  supports_gradient_checkpointing = True
290
  _no_split_modules = ["Ministral3DecoderLayer"]
@@ -304,7 +304,7 @@ class Ministral3PreTrainedModel(PreTrainedModel):
304
  class Ministral3RotaryEmbedding(nn.Module):
305
  inv_freq: torch.Tensor # fix linting for `register_buffer`
306
 
307
- def __init__(self, config: NemotronLabsDiffusionConfig, device=None):
308
  super().__init__()
309
  self.max_seq_len_cached = config.max_position_embeddings
310
  self.original_max_seq_len = config.max_position_embeddings
@@ -323,7 +323,7 @@ class Ministral3RotaryEmbedding(nn.Module):
323
 
324
  @staticmethod
325
  def compute_default_rope_parameters(
326
- config: Optional[NemotronLabsDiffusionConfig] = None,
327
  device: Optional["torch.device"] = None,
328
  seq_len: Optional[int] = None,
329
  ) -> tuple["torch.Tensor", float]:
@@ -370,7 +370,7 @@ class Ministral3RotaryEmbedding(nn.Module):
370
 
371
  @auto_docstring
372
  class Ministral3Model(Ministral3PreTrainedModel):
373
- def __init__(self, config: NemotronLabsDiffusionConfig):
374
  super().__init__(config)
375
  self.padding_idx = config.pad_token_id
376
  self.vocab_size = config.vocab_size
@@ -453,7 +453,99 @@ class Ministral3Model(Ministral3PreTrainedModel):
453
  )
454
 
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  __all__ = [
 
 
457
  "Ministral3Model",
458
  "Ministral3PreTrainedModel",
 
 
459
  ]
 
25
  from transformers.processing_utils import Unpack
26
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
  # from transformers.utils.generic import maybe_autocast
28
+ from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
  #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
 
110
  class Ministral3Attention(nn.Module):
111
  """Multi-headed attention from 'Attention Is All You Need' paper"""
112
 
113
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
114
  super().__init__()
115
  self.config = config
116
  self.layer_idx = layer_idx
 
234
 
235
 
236
  class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
238
  super().__init__()
239
  self.hidden_size = config.hidden_size
240
 
 
284
 
285
  @auto_docstring
286
  class Ministral3PreTrainedModel(PreTrainedModel):
287
+ config: MinistralDLMConfig
288
  base_model_prefix = "model"
289
  supports_gradient_checkpointing = True
290
  _no_split_modules = ["Ministral3DecoderLayer"]
 
304
  class Ministral3RotaryEmbedding(nn.Module):
305
  inv_freq: torch.Tensor # fix linting for `register_buffer`
306
 
307
+ def __init__(self, config: MinistralDLMConfig, device=None):
308
  super().__init__()
309
  self.max_seq_len_cached = config.max_position_embeddings
310
  self.original_max_seq_len = config.max_position_embeddings
 
323
 
324
  @staticmethod
325
  def compute_default_rope_parameters(
326
+ config: Optional[MinistralDLMConfig] = None,
327
  device: Optional["torch.device"] = None,
328
  seq_len: Optional[int] = None,
329
  ) -> tuple["torch.Tensor", float]:
 
370
 
371
  @auto_docstring
372
  class Ministral3Model(Ministral3PreTrainedModel):
373
+ def __init__(self, config: MinistralDLMConfig):
374
  super().__init__(config)
375
  self.padding_idx = config.pad_token_id
376
  self.vocab_size = config.vocab_size
 
453
  )
454
 
455
 
456
+ @auto_docstring
457
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
458
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
459
+ _tp_plan = {"lm_head": "colwise_rep"}
460
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
461
+
462
+ def __init__(self, config):
463
+ super().__init__(config)
464
+ self.model = Ministral3Model(config)
465
+ self.vocab_size = config.vocab_size
466
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
467
+
468
+ # Initialize weights and apply final processing
469
+ self.post_init()
470
+
471
+ @can_return_tuple
472
+ @auto_docstring
473
+ def forward(
474
+ self,
475
+ input_ids: Optional[torch.LongTensor] = None,
476
+ attention_mask: Optional[torch.Tensor] = None,
477
+ position_ids: Optional[torch.LongTensor] = None,
478
+ past_key_values: Optional[Cache] = None,
479
+ inputs_embeds: Optional[torch.FloatTensor] = None,
480
+ labels: Optional[torch.LongTensor] = None,
481
+ use_cache: Optional[bool] = None,
482
+ cache_position: Optional[torch.LongTensor] = None,
483
+ logits_to_keep: Union[int, torch.Tensor] = 0,
484
+ **kwargs: Unpack[TransformersKwargs],
485
+ ) -> CausalLMOutputWithPast:
486
+ r"""
487
+ Example:
488
+
489
+ ```python
490
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
491
+
492
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
493
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
494
+
495
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
496
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
497
+
498
+ >>> # Generate
499
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
500
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
501
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
502
+ ```"""
503
+ outputs: BaseModelOutputWithPast = self.model(
504
+ input_ids=input_ids,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_values=past_key_values,
508
+ inputs_embeds=inputs_embeds,
509
+ use_cache=use_cache,
510
+ cache_position=cache_position,
511
+ **kwargs,
512
+ )
513
+
514
+ hidden_states = outputs.last_hidden_state
515
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
516
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
517
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
518
+
519
+ loss = None
520
+ if labels is not None:
521
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
522
+
523
+ return CausalLMOutputWithPast(
524
+ loss=loss,
525
+ logits=logits,
526
+ past_key_values=outputs.past_key_values,
527
+ hidden_states=outputs.hidden_states,
528
+ attentions=outputs.attentions,
529
+ )
530
+
531
+
532
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
533
+ pass
534
+
535
+
536
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
537
+ pass
538
+
539
+
540
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
541
+ pass
542
+
543
+
544
  __all__ = [
545
+ "Ministral3ForCausalLM",
546
+ "Ministral3ForQuestionAnswering",
547
  "Ministral3Model",
548
  "Ministral3PreTrainedModel",
549
+ "Ministral3ForSequenceClassification",
550
+ "Ministral3ForTokenClassification",
551
  ]
modeling_ministral_dlm.py ADDED
@@ -0,0 +1,1860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional, Tuple, Union
4
+ import random
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
+
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.generation import GenerationMixin
27
+
28
+ import math
29
+
30
+ from .chat_utils import generate_with_prefix_cache_block_diff
31
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
+ from .configuration_ministral_dlm import MinistralDLMConfig
33
+
34
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
35
+
36
+ @dataclass
37
+ class MinistralDiffOutputWithPast(ModelOutput):
38
+ loss: torch.FloatTensor | None = None
39
+ logits: torch.FloatTensor | None = None
40
+ causal_logits: torch.FloatTensor | None = None
41
+ past_key_values: Cache | None = None
42
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
43
+ attentions: tuple[torch.FloatTensor, ...] | None = None
44
+
45
+
46
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
47
+ # @torch.compile(mode="default")
48
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
49
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
50
+ def fused_flex_attention(q, k, v, block_mask=None):
51
+ return flex_attention(q, k, v, block_mask=block_mask)
52
+
53
+
54
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
55
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
56
+ if hasattr(past_key_values, 'crop'):
57
+ past_key_values.crop(max_length)
58
+ else:
59
+ for layer_idx in range(len(past_key_values)):
60
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
61
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
62
+ past_key_values._seen_tokens = max_length
63
+
64
+
65
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
66
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
67
+ for layer_idx in range(len(past_key_values)):
68
+ if hasattr(past_key_values, 'layers'):
69
+ layer_cache = past_key_values.layers[layer_idx]
70
+ k, v = layer_cache.keys, layer_cache.values
71
+ else:
72
+ k = past_key_values.key_cache[layer_idx]
73
+ v = past_key_values.value_cache[layer_idx]
74
+
75
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
76
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
77
+ new_k = torch.cat([clean_k, draft_k], dim=2)
78
+ new_v = torch.cat([clean_v, draft_v], dim=2)
79
+
80
+ if hasattr(past_key_values, 'layers'):
81
+ layer_cache.keys = new_k
82
+ layer_cache.values = new_v
83
+ else:
84
+ past_key_values.key_cache[layer_idx] = new_k
85
+ past_key_values.value_cache[layer_idx] = new_v
86
+
87
+ past_key_values._seen_tokens = clean_len + block_length
88
+
89
+
90
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
91
+ class MinistralFlexAttention(Ministral3Attention):
92
+ def __init__(self, *args, **kwargs):
93
+ super().__init__(*args, **kwargs)
94
+
95
+ self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
96
+ self.block_size_orig = self.config.block_size
97
+
98
+ if self.config.dlm_paradigm == 'bidirectional':
99
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
100
+ elif self.config.dlm_paradigm == 'autoregressive':
101
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
102
+ elif self.config.dlm_paradigm == 'block_diff':
103
+ self.block_diff_mask = None
104
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
105
+ self.sbd_block_diff_mask = None
106
+ else:
107
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
108
+
109
+ self.block_size = self.block_size_orig
110
+ self.mode = self.config.dlm_paradigm
111
+ self._quadratic_block_mask = {}
112
+
113
+ import torch._dynamo.config as dcfg
114
+ dcfg.cache_size_limit = 512
115
+
116
+
117
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
118
+ if block_length not in self._quadratic_block_mask:
119
+ draft_len = block_length * (block_length + 1)
120
+
121
+ def quadratic(b, h, q_idx, kv_idx):
122
+ first_clean = torch.logical_and(
123
+ kv_idx % (block_length + 1) == 0,
124
+ kv_idx < draft_len,
125
+ )
126
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
127
+ block_q = q_idx // (block_length + 1)
128
+ block_kv = kv_idx // (block_length + 1)
129
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
130
+ same_block_except_first = torch.logical_and(
131
+ same_block,
132
+ q_idx % (block_length + 1) != 0,
133
+ )
134
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
135
+ clean_part = kv_idx >= draft_len
136
+ return torch.logical_or(draft_part, clean_part)
137
+
138
+ block_mask = create_block_mask(
139
+ quadratic,
140
+ B=None,
141
+ H=None,
142
+ Q_LEN=draft_len,
143
+ KV_LEN=draft_len + self.config.max_position_embeddings,
144
+ device="cuda",
145
+ )
146
+
147
+ self._quadratic_block_mask[block_length] = block_mask
148
+
149
+ return self._quadratic_block_mask[block_length]
150
+
151
+
152
+ def set_attention_mode(self, mode, block_size=None):
153
+ self.mode = mode
154
+ self.block_size = block_size
155
+
156
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
157
+
158
+ def bidirectional_mask(b, h, q, kv):
159
+ return (q >= kv) | (q < kv)
160
+
161
+ def autoregressive_mask(b, h, q, kv):
162
+ return (q >= kv)
163
+
164
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
165
+ x0_flag_q = (q_idx >= n)
166
+ x0_flag_kv = (kv_idx >= n)
167
+
168
+ # Compute block indices
169
+ block_q = torch.where(x0_flag_q == 1,
170
+ (q_idx - n) // block_size,
171
+ q_idx // block_size)
172
+ block_kv = torch.where(x0_flag_kv == 1,
173
+ (kv_idx - n) // block_size,
174
+ kv_idx // block_size)
175
+
176
+ # **1. Block Diagonal Mask (M_BD) **
177
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
178
+
179
+ # **2. Offset Block-Causal Mask (M_OBC) **
180
+ offset_block_causal = (
181
+ (block_q > block_kv)
182
+ & (x0_flag_kv == 1)
183
+ & (x0_flag_q == 0)
184
+ )
185
+
186
+ # **3. Block-Causal Mask (M_BC) **
187
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
188
+
189
+ # **4. Combine Masks **
190
+ return block_diagonal | offset_block_causal | block_causal
191
+
192
+
193
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
194
+ x0_flag_q = (q_idx >= n)
195
+ x0_flag_kv = (kv_idx >= n)
196
+
197
+ # Compute block indices
198
+ block_q = torch.where(x0_flag_q == 1,
199
+ (q_idx - n) // block_size,
200
+ q_idx // block_size)
201
+ block_kv = torch.where(x0_flag_kv == 1,
202
+ (kv_idx - n) // block_size,
203
+ kv_idx // block_size)
204
+
205
+ # **1. Block Diagonal Mask (M_BD) **
206
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
207
+
208
+ # **2. Offset Block-Causal Mask (M_OBC) **
209
+ offset_block_causal = (
210
+ (block_q > block_kv)
211
+ & (x0_flag_kv == 1)
212
+ & (x0_flag_q == 0)
213
+ )
214
+
215
+ # **3. Fully Causal Mask (M_BC) **
216
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
217
+
218
+ # **4. Combine Masks **
219
+ return block_diagonal | offset_block_causal | fully_causal
220
+
221
+ if mode == 'bidirectional':
222
+ attn_mask = bidirectional_mask
223
+ elif mode == 'autoregressive':
224
+ attn_mask = autoregressive_mask
225
+ elif mode == 'block_diff':
226
+ assert block_size is not None
227
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
228
+ elif mode == 'sbd_block_diff':
229
+ assert block_size is not None
230
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
231
+ else:
232
+ raise ValueError(f"Unknown attention mode: {mode}")
233
+
234
+ if q_len is not None:
235
+ Q_LEN = q_len
236
+ else:
237
+ if mode in ['block_diff', 'sbd_block_diff']:
238
+ Q_LEN = self.max_seq_length * 2
239
+ else:
240
+ Q_LEN = self.max_seq_length
241
+
242
+ block_mask = create_block_mask(
243
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
244
+ )
245
+
246
+ return block_mask
247
+
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
253
+ attention_mask: Optional[torch.Tensor],
254
+ past_key_values: Optional[Cache] = None,
255
+ cache_position: Optional[torch.LongTensor] = None,
256
+ is_training: bool = True,
257
+ **kwargs: Unpack[FlashAttentionKwargs],
258
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
259
+ bsz, q_len, _ = hidden_states.size()
260
+ input_shape = hidden_states.shape[:-1]
261
+ hidden_shape = (*input_shape, -1, self.head_dim)
262
+
263
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
264
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
265
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
266
+
267
+ cos, sin = position_embeddings
268
+
269
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
270
+ # Split query and key states in half along sequence length dimension
271
+ q1, q2 = query_states.chunk(2, dim=2)
272
+ k1, k2 = key_states.chunk(2, dim=2)
273
+
274
+ # Apply RoPE independently to each half
275
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
276
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
277
+
278
+ # Recombine the halves
279
+ query_states = torch.cat([q1, q2], dim=2)
280
+ key_states = torch.cat([k1, k2], dim=2)
281
+ else:
282
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
283
+
284
+ query_states = query_states * _get_llama_4_attn_scale(
285
+ cache_position,
286
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
287
+ self.config.rope_parameters.get("original_max_position_embeddings"),
288
+ ).to(query_states.dtype)
289
+
290
+ if past_key_values is not None:
291
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
292
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
293
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
294
+
295
+ self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
296
+ if self_spec_inference_mode is not None:
297
+ if self_spec_inference_mode == "quadratic":
298
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
299
+ if block_length is None:
300
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
301
+ if past_key_values is not None:
302
+ seq_len = key_states.shape[2]
303
+ draft_len = block_length * (block_length + 1)
304
+
305
+ clean_keys = key_states[:, :, :-draft_len]
306
+ draft_keys = key_states[:, :, -draft_len:]
307
+ clean_values = value_states[:, :, :-draft_len]
308
+ draft_values = value_states[:, :, -draft_len:]
309
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
310
+ value_states = torch.cat([draft_values, clean_values], dim=2)
311
+
312
+ block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
313
+ block_length=block_length
314
+ )
315
+ block_mask.seq_lengths = (draft_len, seq_len)
316
+ else:
317
+ seq_len = query_states.shape[2]
318
+ draft_len = block_length * (block_length + 1)
319
+ clean_len = seq_len - draft_len
320
+
321
+ def _causal_mask(b, h, q_idx, kv_idx):
322
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
323
+
324
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
325
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
326
+ first_clean = torch.logical_and(
327
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
328
+ )
329
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
330
+ return torch.logical_or(full_clean, first_clean)
331
+
332
+ def _draft_mask(b, h, q_idx, kv_idx):
333
+ block_q = (q_idx - clean_len) // (block_length + 1)
334
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
335
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
336
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
337
+ same_block_except_first = torch.logical_and(
338
+ same_block,
339
+ (q_idx - clean_len) % (block_length + 1) != 0,
340
+ )
341
+ return torch.logical_and(block_q == block_kv, same_block_except_first)
342
+
343
+ mask = or_masks(_causal_mask, _draft2clean_mask)
344
+ mask = or_masks(mask, _draft_mask)
345
+
346
+ block_mask = create_block_mask(
347
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
348
+ )
349
+
350
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
351
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
352
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
353
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
354
+ attn_output = self.o_proj(attn_output)
355
+ return attn_output, None
356
+
357
+ elif self_spec_inference_mode == "default":
358
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
359
+ if block_length is None:
360
+ raise ValueError("SBD default decoding requires block_length in config.")
361
+ seq_len = query_states.shape[2]
362
+ prefix_len = seq_len - block_length
363
+
364
+ def _clean_q_mask(b, h, q_idx, kv_idx):
365
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
366
+
367
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
368
+ return q_idx >= prefix_len
369
+
370
+ block_mask = create_block_mask(
371
+ or_masks(_clean_q_mask, _noisy_q_mask),
372
+ B=None,
373
+ H=None,
374
+ Q_LEN=seq_len,
375
+ KV_LEN=seq_len,
376
+ )
377
+
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
381
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
382
+ attn_output = self.o_proj(attn_output)
383
+ return attn_output, None
384
+
385
+ else:
386
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
387
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
388
+
389
+ if self.mode == 'bidirectional':
390
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
391
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
392
+ else:
393
+ block_mask = self.bidirectional_mask
394
+
395
+ elif self.mode == 'autoregressive':
396
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
397
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
398
+ else:
399
+ block_mask = self.autoregressive_mask
400
+
401
+ elif self.mode == 'block_diff':
402
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
403
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
404
+ else:
405
+ block_mask = self.block_diff_mask
406
+ elif self.mode == 'sbd_block_diff':
407
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
408
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
409
+ else:
410
+ block_mask = self.sbd_block_diff_mask
411
+ else:
412
+ raise ValueError(f"Unknown attention mode: {self.mode}")
413
+
414
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
415
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
416
+
417
+ attn_output = self.o_proj(attn_output)
418
+
419
+ return attn_output, None
420
+
421
+
422
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
423
+ """Return a Bool mask of length len(log_w) with exactly k True."""
424
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
425
+ topk = torch.topk(log_w + g, k).indices
426
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
427
+ mask[topk] = True
428
+ return mask
429
+
430
+
431
+ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
432
+ """
433
+ A single model with:
434
+ - a bidirectional encoder + diffusion‐LM head over A
435
+ - a causal decoder + LM head over B, conditioned on F_A
436
+ """
437
+
438
+ def __init__(self, config: MinistralDLMConfig):
439
+ super().__init__(config)
440
+
441
+ self.mask_token_id = config.mask_token_id
442
+
443
+ diffusion_config = copy.deepcopy(config)
444
+ diffusion_config.diffusion_lm = True
445
+
446
+ use_flex = getattr(config, 'enable_self_spec', False)
447
+
448
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
449
+ diffusion_config.attn_class = MinistralFlexAttention
450
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
451
+ diffusion_config.attn_class = MinistralFlexAttention if use_flex else Ministral3Attention
452
+ if config.dlm_paradigm == 'autoregressive':
453
+ diffusion_config.diffusion_lm = False
454
+ else:
455
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
456
+
457
+ self.encoder = Ministral3Model(diffusion_config)
458
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
459
+ self.vocab_size = config.vocab_size
460
+
461
+ self.current_iter_ratio = None
462
+
463
+ self.post_init()
464
+
465
+
466
+ def get_input_embeddings(self):
467
+ return self.encoder.embed_tokens
468
+
469
+ def set_input_embeddings(self, value):
470
+ self.encoder.embed_tokens = value
471
+
472
+ def get_output_embeddings(self):
473
+ return self.diffusion_head
474
+
475
+ def set_output_embeddings(self, new_embeddings):
476
+ self.diffusion_head = new_embeddings
477
+
478
+
479
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
480
+ b, l = input_ids.shape
481
+ device = input_ids.device
482
+
483
+ if self.config.dp_varying_mask_ratio:
484
+ # Enable different random seeds for each DP rank during sampling
485
+ import torch.distributed as dist
486
+ dp_rank = 0
487
+ if dist.is_initialized():
488
+ try:
489
+ dp_rank = dist.get_rank()
490
+ except Exception:
491
+ dp_rank = 0
492
+ # Use a local generator to avoid affecting global RNG state
493
+ generator = torch.Generator(device=device)
494
+ generator.manual_seed(torch.seed() + dp_rank)
495
+ else:
496
+ generator = None
497
+
498
+ if self.config.adaptive_mask_rate:
499
+ assert block_size is not None
500
+
501
+ # --- simple linear window mapping ---
502
+ bs_min = getattr(self.config, "t_bs_min", 16)
503
+ bs_max = getattr(self.config, "t_bs_max", 128)
504
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
505
+
506
+ # fraction in [0,1] (unclamped first)
507
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
508
+ # upper bound decreases linearly from 1.0 -> 0.5
509
+ u_max = 1.0 - w * frac
510
+ # clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
511
+ u_max = max(0.6, min(1.0, u_max))
512
+ u_min = u_max - w # ensures width = w
513
+
514
+ # sample t ~ Uniform(u_min, u_max)
515
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
516
+ else:
517
+ t = torch.rand(b, device=device, generator=generator)
518
+
519
+ p_mask = (1 - eps) * t + eps # shape: (b,)
520
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
521
+
522
+ masked_indices = torch.rand((b, l), device=device) < p_mask
523
+
524
+ if loss_mask is not None:
525
+ masked_indices[loss_mask == 0] = 0
526
+
527
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
528
+
529
+ return noisy_batch, masked_indices, p_mask
530
+
531
+
532
+ def forward_process_exp(
533
+ self,
534
+ input_ids: torch.Tensor,
535
+ eps: float = 1e-3,
536
+ block_size: int | None = None,
537
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
538
+ loss_mask: Optional[torch.Tensor] = None,
539
+ ):
540
+ """
541
+ Two-stage corruption with optional per-block sampling.
542
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
543
+ • Stage 2: sample exactly k positions with weights
544
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
545
+ uniform when m→1).
546
+ If `block_size` is given, the procedure is run *independently*
547
+ inside each contiguous block of that length (last block may be shorter).
548
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
549
+ Args
550
+ ----
551
+ input_ids : (B, L) LongTensor
552
+ eps : minimum corruption ratio
553
+ block_size: if not None, operate block-wise with per-block m sampling
554
+ half_life_ratio : controls steepness when m→0
555
+ """
556
+ B, L = input_ids.shape
557
+ device = input_ids.device
558
+ dtype = torch.float32
559
+
560
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
561
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
562
+
563
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
564
+ for b in range(B):
565
+ if block_size is None:
566
+ # ---------- Per-batch sampling (original behavior) ----------
567
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
568
+ k_tot = int(round(m * L))
569
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
570
+
571
+ # Fill p_mask for this batch
572
+ p_mask[b, :] = m
573
+
574
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
575
+
576
+ # ------- single pool over the whole sentence -------------
577
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
578
+
579
+ pos = torch.arange(L, device=device, dtype=dtype)
580
+ log_w = (lam_base * slope * pos).clone()
581
+
582
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
583
+
584
+ else:
585
+ # ---------- Per-block sampling ----------
586
+ num_blocks = math.ceil(L / block_size)
587
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
588
+
589
+ for blk in range(num_blocks):
590
+ start = blk * block_size
591
+ end = min((blk + 1) * block_size, L)
592
+ blk_len = end - start
593
+
594
+ # Sample m per block
595
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
596
+
597
+ # Fill p_mask for this block
598
+ p_mask[b, start:end] = m_blk
599
+
600
+ # per-block budget
601
+ k_blk = int(round(m_blk * blk_len))
602
+ k_blk = max(0, min(k_blk, blk_len))
603
+ if k_blk == 0:
604
+ continue
605
+
606
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
607
+
608
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
609
+ log_w = lam_base * slope * pos
610
+
611
+ blk_mask = gumbel_topk(log_w, k_blk)
612
+ masked_indices[b, start:end] = blk_mask
613
+
614
+ if loss_mask is not None:
615
+ masked_indices[loss_mask == 0] = 0
616
+
617
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
618
+ return noisy_batch, masked_indices, p_mask
619
+
620
+
621
+ def forward(
622
+ self,
623
+ input_ids: torch.LongTensor,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ position_ids: Optional[torch.LongTensor] = None,
626
+ labels: Optional[torch.LongTensor] = None,
627
+ split_len: Optional[int] = None,
628
+ past_key_values: Optional[Cache] = None,
629
+ block_size: Optional[int] = None,
630
+ block_diff_ppl: bool = False,
631
+ eps: float = 1e-3,
632
+ is_teacher: bool = False,
633
+ masked_indices: Optional[torch.Tensor] = None,
634
+ p_mask: Optional[torch.Tensor] = None,
635
+ teacher_logits: Optional[torch.Tensor] = None,
636
+ masked_indices_teacher: Optional[torch.Tensor] = None,
637
+ loss_mask: Optional[torch.Tensor] = None,
638
+ ce_loss_weight: float = 1.0,
639
+ output_last_hidden_states_only: bool = False,
640
+ skip_loss: bool = False,
641
+ **kwargs,
642
+ ) -> CausalLMOutputWithPast:
643
+
644
+ batch_size, seq_len = input_ids.shape
645
+
646
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
647
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
648
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
649
+ input_ids = input_ids[:, :random_length]
650
+ labels = labels[:, :random_length]
651
+
652
+ if attention_mask is not None:
653
+ attention_mask = attention_mask[:, :random_length]
654
+ if position_ids is not None:
655
+ position_ids = position_ids[:, :random_length]
656
+ if loss_mask is not None:
657
+ loss_mask = loss_mask[:, :random_length]
658
+
659
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
660
+ if labels is not None and block_size is None:
661
+ if torch.rand(1) < self.config.random_length_prob:
662
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
663
+ else:
664
+ block_size = self.config.block_size
665
+
666
+ else:
667
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
668
+
669
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
670
+ if masked_indices is not None:
671
+ # assert p_mask is not None
672
+
673
+ if loss_mask is not None:
674
+ masked_indices[loss_mask == 0] = 0
675
+
676
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
677
+
678
+ else:
679
+ if self.config.tok_mask_half_life_ratio is not None:
680
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
681
+ else:
682
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
683
+
684
+ else:
685
+ noisy_inputs = input_ids
686
+ masked_indices = None
687
+ p_mask = None
688
+
689
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
690
+ for layer in self.encoder.layers:
691
+ if hasattr(layer.self_attn, 'set_attention_mode'):
692
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
693
+
694
+ input_ids_len = noisy_inputs.shape[1]
695
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
696
+ if position_ids is None:
697
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
698
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
699
+
700
+ if block_diff_ppl:
701
+ if position_ids is None:
702
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
703
+
704
+ enc_out = self.encoder(
705
+ past_key_values=past_key_values,
706
+ input_ids=noisy_inputs,
707
+ attention_mask=attention_mask,
708
+ position_ids=position_ids,
709
+ is_training=(labels is not None) or (block_diff_ppl),
710
+ **kwargs,
711
+ )
712
+
713
+ if output_last_hidden_states_only:
714
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
715
+
716
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
717
+ causal_logits = None
718
+
719
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
720
+ if self.config.dlm_paradigm == 'sbd_block_diff':
721
+ causal_logits = logits[:, input_ids_len:]
722
+ else:
723
+ causal_logits = None
724
+
725
+ logits = logits[:, :input_ids_len]
726
+
727
+ loss = None
728
+ if labels is not None and not skip_loss:
729
+ if self.config.dlm_paradigm == 'autoregressive':
730
+ shift_logits = logits[..., :-1, :].contiguous()
731
+ shift_labels = labels[..., 1:].contiguous()
732
+
733
+ if loss_mask is None:
734
+ loss_fct = CrossEntropyLoss()
735
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
736
+ shift_labels = shift_labels.view(-1)
737
+ loss = loss_fct(shift_logits, shift_labels)
738
+
739
+ else:
740
+ loss_mask = loss_mask[..., 1:].contiguous()
741
+
742
+ loss_fct = CrossEntropyLoss(reduction='none')
743
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
744
+ shift_labels = shift_labels.view(-1)
745
+ shift_labels = shift_labels.to(shift_logits.device)
746
+
747
+ token_losses = loss_fct(shift_logits, shift_labels)
748
+
749
+ flat_loss_mask = loss_mask.reshape(-1)
750
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
751
+
752
+ else:
753
+ # Handle DREAM vs LLADA style losses
754
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
755
+ logits = logits[..., :-1, :].contiguous()
756
+ labels = labels[..., 1:].contiguous()
757
+ masked_indices = masked_indices[:, 1:]
758
+ p_mask = p_mask[:, 1:]
759
+
760
+ if self.config.ada_perm_ratio_per_block is not None:
761
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
762
+ block_size = self.config.block_size
763
+ batch_size, seq_len = masked_indices.shape
764
+ num_blocks = seq_len // block_size
765
+
766
+ # Get the max logit (confidence) for each position
767
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
768
+
769
+ # Create a mask for tokens to include in loss
770
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
771
+
772
+ for blk in range(num_blocks):
773
+ start = blk * block_size
774
+ end = min((blk + 1) * block_size, seq_len)
775
+
776
+ # Get masked indices within this block
777
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
778
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
779
+
780
+ for b in range(batch_size):
781
+ # Get positions that are masked in this block for this batch
782
+ masked_positions = torch.where(block_masked[b])[0]
783
+ num_masked = len(masked_positions)
784
+
785
+ if num_masked > 0:
786
+ # Number of tokens to keep (top by confidence)
787
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
788
+
789
+ # Get confidence values for masked positions
790
+ masked_confidence = block_confidence[b, masked_positions]
791
+
792
+ # Get indices of top-k confident tokens
793
+ _, topk_indices = torch.topk(masked_confidence, k)
794
+ selected_positions = masked_positions[topk_indices]
795
+
796
+ # Mark these positions in the selected mask
797
+ selected_mask[b, start + selected_positions] = True
798
+
799
+ # Calculate loss only for selected positions
800
+ token_loss = torch.nn.functional.cross_entropy(
801
+ logits[selected_mask],
802
+ labels[selected_mask],
803
+ reduction='none'
804
+ ) / p_mask[selected_mask]
805
+
806
+ num_mask_tokens = selected_mask.sum()
807
+
808
+ else:
809
+ # Calculate token-wise cross entropy loss for masked positions in B
810
+ token_loss = torch.nn.functional.cross_entropy(
811
+ logits[masked_indices],
812
+ labels[masked_indices],
813
+ reduction='none'
814
+ ) / p_mask[masked_indices]
815
+
816
+ num_mask_tokens = masked_indices.sum()
817
+
818
+ if self.config.global_loss_avg:
819
+ loss = token_loss.sum()
820
+ else:
821
+ loss = token_loss.sum() / num_mask_tokens
822
+
823
+ if self.config.ada_dlm_loss_ratio is not None:
824
+ assert self.current_iter_ratio is not None
825
+ assert self.config.dlm_loss_weight is not None
826
+
827
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
828
+ loss = dlm_loss_weight * loss
829
+
830
+ elif self.config.dlm_loss_weight is not None:
831
+ loss = self.config.dlm_loss_weight * loss
832
+
833
+ if self.config.dlm_paradigm == 'sbd_block_diff':
834
+ causal_logits = causal_logits[..., :-1, :].contiguous()
835
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
836
+
837
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
838
+ causal_labels = labels.view(-1)
839
+ else:
840
+ causal_labels = labels[..., 1:].contiguous().view(-1)
841
+
842
+ if self.config.global_loss_avg:
843
+ loss_fct = CrossEntropyLoss(reduction='sum')
844
+ ar_loss = loss_fct(causal_logits, causal_labels)
845
+
846
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
847
+ self.loss_ar = ar_loss.detach().item() / seq_len
848
+
849
+ loss = loss + self.config.ar_loss_weight * ar_loss
850
+ else:
851
+ loss_fct = CrossEntropyLoss()
852
+ ar_loss = loss_fct(causal_logits, causal_labels)
853
+
854
+ self.loss_diffusion = loss.detach().item()
855
+ self.loss_ar = ar_loss.detach().item()
856
+
857
+ loss = loss + self.config.ar_loss_weight * ar_loss
858
+
859
+ if self.config.global_loss_avg:
860
+ if self.config.dlm_paradigm == 'sbd_block_diff':
861
+ loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
862
+ else:
863
+ loss = (loss, num_mask_tokens)
864
+
865
+ return MinistralDiffOutputWithPast(
866
+ loss=loss if not is_teacher else logits,
867
+ logits=logits,
868
+ causal_logits=causal_logits,
869
+ past_key_values=enc_out.past_key_values,
870
+ hidden_states=None,
871
+ attentions=None,
872
+ )
873
+
874
+
875
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None, max_thinking_tokens=None, end_think_token_id=None):
876
+ if eos_token_id is None:
877
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
878
+
879
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
880
+ model=self,
881
+ prompt=prompt_ids,
882
+ gen_length=max_new_tokens,
883
+ steps=steps,
884
+ block_length=block_length,
885
+ remasking="low_confidence",
886
+ temperature=temperature,
887
+ mask_id=self.mask_token_id,
888
+ threshold=threshold,
889
+ shift_logits=shift_logits,
890
+ neg_entropy=False,
891
+ causal_context=causal_context,
892
+ eos_token_id=eos_token_id,
893
+ max_thinking_tokens=max_thinking_tokens,
894
+ end_think_token_id=end_think_token_id,
895
+ )
896
+
897
+ return out_ids, nfe
898
+
899
+
900
+ @torch.no_grad()
901
+ def sbd_inference_diffusion_quadratic(
902
+ self,
903
+ clean_input_ids: Optional[torch.Tensor],
904
+ draft_input_ids: torch.Tensor,
905
+ block_length: int,
906
+ draft_only: bool = False,
907
+ past_key_values: Optional[Cache] = None,
908
+ use_cache: bool = False,
909
+ ):
910
+ enc_config = self.encoder.config
911
+ enc_config.use_sbd_objective = True
912
+ enc_config.block_length = block_length
913
+
914
+ if draft_only:
915
+ assert clean_input_ids is not None
916
+
917
+ if use_cache and past_key_values is None:
918
+ past_key_values = DynamicCache()
919
+
920
+ enc_config.self_spec_inference_mode = "default"
921
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
922
+ outputs = self.encoder(
923
+ input_ids=input_ids,
924
+ position_ids=None,
925
+ past_key_values=past_key_values,
926
+ use_cache=use_cache,
927
+ is_training=False,
928
+ )
929
+
930
+ hidden_states = outputs.last_hidden_state
931
+ logits = self.diffusion_head(hidden_states)
932
+
933
+ past_key_values = getattr(outputs, "past_key_values", None)
934
+ if use_cache and past_key_values is not None:
935
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
936
+
937
+ return logits, past_key_values
938
+ else:
939
+ enc_config.self_spec_inference_mode = "quadratic"
940
+
941
+ draft_len = block_length * (block_length + 1)
942
+ draft_input_ids = torch.cat(
943
+ [
944
+ draft_input_ids.view(-1, block_length, 1),
945
+ torch.full(
946
+ (draft_input_ids.shape[0], block_length, block_length),
947
+ fill_value=self.config.mask_token_id,
948
+ device=draft_input_ids.device,
949
+ ),
950
+ ],
951
+ dim=-1,
952
+ ).view(-1, draft_len)
953
+
954
+ if use_cache:
955
+ assert past_key_values is not None, (
956
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
957
+ )
958
+ assert clean_input_ids is None, (
959
+ "Clean input ids should already be in cache, thus none should be provided."
960
+ )
961
+ clean_len = past_key_values.get_seq_length()
962
+ input_ids = draft_input_ids
963
+ else:
964
+ clean_len = clean_input_ids.shape[1]
965
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
966
+
967
+ per_block_position_ids = torch.arange(
968
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
969
+ )[None,].repeat(block_length, 1)
970
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
971
+
972
+ if use_cache:
973
+ position_ids = per_block_position_ids.view(-1)[None,]
974
+ else:
975
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
976
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
977
+
978
+ outputs = self.encoder(
979
+ input_ids=input_ids,
980
+ position_ids=position_ids,
981
+ past_key_values=past_key_values,
982
+ use_cache=use_cache,
983
+ is_training=False,
984
+ )
985
+
986
+ hidden_states = outputs.last_hidden_state
987
+ logits = self.diffusion_head(hidden_states)
988
+ past_key_values = getattr(outputs, "past_key_values", None)
989
+
990
+ if use_cache and past_key_values is not None:
991
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
992
+
993
+ return logits, past_key_values
994
+
995
+
996
+ @torch.no_grad()
997
+ def ar_generate(
998
+ self,
999
+ prompt_ids: torch.Tensor,
1000
+ max_new_tokens: int = 128,
1001
+ temperature: float = 0.0,
1002
+ eos_token_id: Optional[int] = None,
1003
+ max_thinking_tokens: Optional[int] = None,
1004
+ end_think_token_id: Optional[int] = None,
1005
+ ) -> tuple:
1006
+ """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
1007
+
1008
+ Bypasses MinistralDiffEncoderModel.forward() to avoid diffusion-specific
1009
+ code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
1010
+ position_ids, and use_cache so the KV cache and causal masking behave
1011
+ identically to MistralForCausalLM / vLLM.
1012
+
1013
+ Returns:
1014
+ (output_ids, nfe) where output_ids includes the prompt.
1015
+ """
1016
+ for layer in self.encoder.layers:
1017
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1018
+ layer.self_attn.diffusion_lm = False
1019
+
1020
+ if eos_token_id is None:
1021
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
1022
+
1023
+ device = prompt_ids.device
1024
+ batch_size, prompt_len = prompt_ids.shape
1025
+
1026
+ past_key_values = DynamicCache()
1027
+ cache_position = torch.arange(prompt_len, device=device)
1028
+ position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
1029
+
1030
+ enc_out = self.encoder(
1031
+ input_ids=prompt_ids,
1032
+ position_ids=position_ids,
1033
+ past_key_values=past_key_values,
1034
+ use_cache=True,
1035
+ cache_position=cache_position,
1036
+ )
1037
+ past_key_values = enc_out.past_key_values
1038
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1039
+
1040
+ generated_tokens = []
1041
+ nfe = 0
1042
+
1043
+ for step in range(max_new_tokens):
1044
+ nfe += 1
1045
+
1046
+ if temperature > 0:
1047
+ probs = torch.softmax(next_logit / temperature, dim=-1)
1048
+ next_token = torch.multinomial(probs, num_samples=1)
1049
+ else:
1050
+ next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
1051
+
1052
+ # ---- thinking budget enforcement ----
1053
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1054
+ if step >= max_thinking_tokens:
1055
+ if generated_tokens:
1056
+ gen_tensor = torch.cat(generated_tokens, dim=1)
1057
+ has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
1058
+ else:
1059
+ has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
1060
+ for b in range(batch_size):
1061
+ if not has_end_think[b]:
1062
+ next_token[b] = end_think_token_id
1063
+
1064
+ generated_tokens.append(next_token)
1065
+
1066
+ if eos_token_id is not None and (next_token == eos_token_id).all():
1067
+ break
1068
+
1069
+ if step < max_new_tokens - 1:
1070
+ cur_pos = prompt_len + step
1071
+ step_cache_pos = torch.tensor([cur_pos], device=device)
1072
+ step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
1073
+
1074
+ enc_out = self.encoder(
1075
+ input_ids=next_token,
1076
+ position_ids=step_pos_ids,
1077
+ past_key_values=past_key_values,
1078
+ use_cache=True,
1079
+ cache_position=step_cache_pos,
1080
+ )
1081
+ past_key_values = enc_out.past_key_values
1082
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1083
+
1084
+ all_generated = torch.cat(generated_tokens, dim=1)
1085
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1086
+ return output_ids, nfe
1087
+
1088
+
1089
+ @torch.no_grad()
1090
+ def self_spec_generate(
1091
+ self,
1092
+ prompt_ids: torch.Tensor,
1093
+ max_new_tokens: int = 128,
1094
+ steps: int = 128,
1095
+ block_length: int = 16,
1096
+ ar_mix_weight: Optional[float] = None,
1097
+ temperature: float = 0.0,
1098
+ mask_token_id: Optional[int] = None,
1099
+ eos_token_id: Optional[int] = None,
1100
+ max_thinking_tokens: Optional[int] = None,
1101
+ end_think_token_id: Optional[int] = None,
1102
+ ):
1103
+ self.config.use_sbd_objective = True
1104
+ self.config.dlm_paradigm = "sbd"
1105
+
1106
+ if prompt_ids.shape[0] != 1:
1107
+ raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
1108
+
1109
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1110
+ if eos_token_id is None:
1111
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1112
+
1113
+ x = torch.full(
1114
+ (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
1115
+ token_mask_id,
1116
+ dtype=torch.long,
1117
+ device=prompt_ids.device,
1118
+ )
1119
+ x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
1120
+
1121
+ if max_new_tokens % block_length != 0:
1122
+ raise ValueError("max_new_tokens must be divisible by block_length")
1123
+ num_blocks = max_new_tokens // block_length
1124
+ if steps % num_blocks != 0:
1125
+ raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
1126
+
1127
+ prompt_len = prompt_ids.shape[1]
1128
+ nfe = 0
1129
+ nfe += 1
1130
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1131
+ clean_input_ids=x[:, :prompt_len],
1132
+ draft_input_ids=x[:, prompt_len : prompt_len + block_length],
1133
+ block_length=block_length,
1134
+ draft_only=True,
1135
+ use_cache=True,
1136
+ )
1137
+
1138
+ logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
1139
+ logits_proposal[:, 1] = logits_proposal[:, 0]
1140
+ logits_proposal = logits_proposal[:, 1:]
1141
+ x0_proposal = torch.argmax(logits_proposal, dim=-1)
1142
+ x[:, prompt_len : prompt_len + block_length] = x0_proposal
1143
+
1144
+ total_accept_token = 0
1145
+ while True:
1146
+ nfe += 1
1147
+ block_start = prompt_len + total_accept_token
1148
+ block_end = block_start + block_length
1149
+ draft_input_ids = x[:, block_start:block_end]
1150
+
1151
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1152
+ clean_input_ids=None,
1153
+ draft_input_ids=draft_input_ids,
1154
+ block_length=block_length,
1155
+ draft_only=False,
1156
+ past_key_values=past_key_values,
1157
+ use_cache=True,
1158
+ )
1159
+
1160
+ useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1161
+ if ar_mix_weight is None:
1162
+ useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1163
+ else:
1164
+ if not (0.0 <= ar_mix_weight <= 1.0):
1165
+ raise ValueError("ar_mix_weight must be between 0 and 1")
1166
+ mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
1167
+ useful_token_logits[:, :, 0] = mix_logits
1168
+ useful_token_logits[:, :, 1] = mix_logits
1169
+
1170
+ if temperature > 0:
1171
+ useful_token_logits = useful_token_logits / temperature
1172
+
1173
+ useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
1174
+ new_draft_input_ids = useful_token_pred[:, 0, 1:]
1175
+ accept_cnt = 1
1176
+
1177
+ while accept_cnt < block_length:
1178
+ if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
1179
+ break
1180
+ new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
1181
+ accept_cnt += 1
1182
+
1183
+ x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
1184
+
1185
+ # EoS early stopping: all accepted tokens are finalized left-to-right,
1186
+ # so if any is EoS we can truncate and return immediately.
1187
+ if eos_token_id is not None:
1188
+ accepted = x[0, block_start : block_start + accept_cnt]
1189
+ eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
1190
+ if len(eos_positions) > 0:
1191
+ first_eos_rel = eos_positions[0].item()
1192
+ total_accept_token += first_eos_rel + 1
1193
+ output_end = prompt_len + total_accept_token
1194
+ return x[:, :output_end], nfe
1195
+
1196
+ x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1197
+ past_key_values.crop(block_start + accept_cnt)
1198
+
1199
+ # ---- thinking budget enforcement ----
1200
+ # Insert end_think as the first token of the next draft block,
1201
+ # shifting all subsequent tokens right by 1 (discarding the last).
1202
+ # The first draft token is always accepted unconditionally, so
1203
+ # end_think is guaranteed to be finalized in the next iteration
1204
+ # without needing to re-encode or touch the KV cache.
1205
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1206
+ tokens_so_far = total_accept_token + accept_cnt
1207
+ if tokens_so_far > max_thinking_tokens:
1208
+ gen_so_far = x[0, prompt_len : prompt_len + tokens_so_far]
1209
+ has_end_think = (gen_so_far == end_think_token_id).any()
1210
+ if not has_end_think:
1211
+ insert_pos = block_start + accept_cnt
1212
+ x[0, insert_pos + 1:] = x[0, insert_pos:-1].clone()
1213
+ x[0, insert_pos] = end_think_token_id
1214
+
1215
+ total_accept_token += accept_cnt
1216
+
1217
+ if total_accept_token >= max_new_tokens:
1218
+ break
1219
+
1220
+ return x[:, : -(block_length * 2)], nfe
1221
+
1222
+
1223
+ @torch.no_grad()
1224
+ def linear_spec_generate(
1225
+ self,
1226
+ prompt_ids: torch.Tensor,
1227
+ max_new_tokens: int = 128,
1228
+ block_length: int = 32,
1229
+ temperature: float = 0.0,
1230
+ mask_token_id: Optional[int] = None,
1231
+ eos_token_id: Optional[int] = None,
1232
+ max_thinking_tokens: Optional[int] = None,
1233
+ end_think_token_id: Optional[int] = None,
1234
+ threshold: float = 0.0,
1235
+ ):
1236
+ """Linear speculative decoding: diffusion draft + AR verification.
1237
+
1238
+ Each step:
1239
+ 1. Draft: forward [last_accepted, mask, ...] with bidirectional attention
1240
+ (diffusion_lm=True, use_cache=False). Shift AR logits to get
1241
+ per-position predictions; apply confidence filtering.
1242
+ 2. Verify: forward the drafted block with causal attention
1243
+ (diffusion_lm=False, use_cache=True, use_causal_mask=True).
1244
+ Accept consecutive AR-matching tokens plus one bonus token.
1245
+
1246
+ Args:
1247
+ prompt_ids: Input token IDs of shape (1, prompt_len).
1248
+ max_new_tokens: Maximum number of tokens to generate.
1249
+ block_length: Number of tokens per draft/verify block.
1250
+ temperature: Sampling temperature (0 = greedy).
1251
+ mask_token_id: Override for config.mask_token_id.
1252
+ eos_token_id: Override for config.eos_token_id.
1253
+ max_thinking_tokens: Budget for thinking tokens before forcing end_think.
1254
+ end_think_token_id: Token ID inserted when thinking budget is exceeded.
1255
+ threshold: Confidence threshold for accepting draft predictions.
1256
+
1257
+ Returns:
1258
+ (output_ids, nfe): output_ids includes the prompt; nfe is the number
1259
+ of forward evaluations (matching self_spec_generate interface).
1260
+ """
1261
+ if prompt_ids.shape[0] != 1:
1262
+ raise ValueError("Linear speculative decoding requires batch_size == 1")
1263
+
1264
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1265
+ if eos_token_id is None:
1266
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1267
+
1268
+ device = prompt_ids.device
1269
+ prompt_len = prompt_ids.shape[1]
1270
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1271
+
1272
+ def _set_diffusion_lm(val: bool):
1273
+ for layer in self.encoder.layers:
1274
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1275
+ layer.self_attn.diffusion_lm = val
1276
+
1277
+ # ===== Prefill (causal) =====
1278
+ _set_diffusion_lm(False)
1279
+
1280
+ enc_out = self.encoder(
1281
+ input_ids=prompt_ids,
1282
+ past_key_values=DynamicCache(),
1283
+ use_cache=True,
1284
+ use_causal_mask=True,
1285
+ )
1286
+ past_key_values = enc_out.past_key_values
1287
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1288
+ nfe = 1
1289
+
1290
+ if temperature > 0:
1291
+ probs = torch.softmax(last_logit / temperature, dim=-1)
1292
+ next_token = torch.multinomial(probs, num_samples=1)
1293
+ else:
1294
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1295
+
1296
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1297
+ output_ids = torch.cat([prompt_ids, next_token], dim=1)
1298
+ return output_ids, nfe
1299
+
1300
+ generated = [next_token]
1301
+ total_gen = 1
1302
+
1303
+ # ===== Main loop =====
1304
+ while total_gen < max_new_tokens:
1305
+ cache_len = past_key_values.get_seq_length()
1306
+
1307
+ block = torch.full(
1308
+ (1, block_length), token_mask_id, dtype=torch.long, device=device
1309
+ )
1310
+ block[0, 0] = next_token.item()
1311
+
1312
+ # -------- Draft (bidirectional, don't update cache) --------
1313
+ _set_diffusion_lm(True)
1314
+ while True:
1315
+ is_mask = block == token_mask_id
1316
+ if not is_mask.any():
1317
+ break
1318
+
1319
+ enc_out = self.encoder(
1320
+ input_ids=block,
1321
+ past_key_values=past_key_values,
1322
+ use_cache=False,
1323
+ )
1324
+ nfe += 1
1325
+
1326
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1327
+ if dream_style:
1328
+ # DREAM: logit[i] predicts position i+1 → shift to self-prediction
1329
+ draft_logits = torch.cat(
1330
+ [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1
1331
+ )
1332
+ # LLaDA: logit[i] already predicts position i → no shift needed
1333
+
1334
+ if temperature > 0:
1335
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1336
+ draft_tokens = torch.multinomial(
1337
+ draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
1338
+ ).view(1, block_length)
1339
+ else:
1340
+ draft_tokens = draft_logits.argmax(dim=-1)
1341
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1342
+
1343
+ if threshold > 0:
1344
+ draft_conf = torch.gather(
1345
+ draft_probs, -1, draft_tokens.unsqueeze(-1)
1346
+ ).squeeze(-1)
1347
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1348
+ unmask = draft_conf >= threshold
1349
+
1350
+ # Ensure each iteration makes progress even when every masked
1351
+ # position falls below the confidence threshold.
1352
+ if not unmask.any():
1353
+ best_idx = draft_conf.view(-1).argmax()
1354
+ unmask = torch.zeros_like(is_mask, dtype=torch.bool)
1355
+ unmask.view(-1)[best_idx] = True
1356
+
1357
+ block[unmask] = draft_tokens[unmask]
1358
+ else:
1359
+ block[is_mask] = draft_tokens[is_mask]
1360
+ break
1361
+
1362
+ # -------- Verify (causal, update cache) --------
1363
+ _set_diffusion_lm(False)
1364
+ enc_out = self.encoder(
1365
+ input_ids=block,
1366
+ past_key_values=past_key_values,
1367
+ use_cache=True,
1368
+ use_causal_mask=True,
1369
+ )
1370
+ past_key_values = enc_out.past_key_values
1371
+ nfe += 1
1372
+
1373
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1374
+ if temperature > 0:
1375
+ verify_probs = torch.softmax(verify_logits / temperature, dim=-1)
1376
+ ar_tokens = torch.multinomial(
1377
+ verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
1378
+ ).view(1, block_length)
1379
+ else:
1380
+ ar_tokens = verify_logits.argmax(dim=-1)
1381
+
1382
+ accepted = 0
1383
+ for i in range(block_length - 1):
1384
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1385
+ accepted += 1
1386
+ else:
1387
+ break
1388
+ accepted += 1 # bonus token from AR verification
1389
+
1390
+ accepted_toks = ar_tokens[:, :accepted]
1391
+ generated.append(accepted_toks)
1392
+ total_gen += accepted
1393
+
1394
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1395
+
1396
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1397
+
1398
+ # -------- EOS check --------
1399
+ if eos_token_id is not None:
1400
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1401
+ if len(eos_pos) > 0:
1402
+ first_eos = eos_pos[0].item()
1403
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1404
+ total_gen = total_gen - accepted + first_eos + 1
1405
+ break
1406
+
1407
+ # -------- Thinking budget enforcement --------
1408
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1409
+ if total_gen > max_thinking_tokens:
1410
+ all_gen = torch.cat(generated, dim=1)
1411
+ if not (all_gen == end_think_token_id).any():
1412
+ next_token = torch.tensor(
1413
+ [[end_think_token_id]], device=device
1414
+ )
1415
+
1416
+ if total_gen >= max_new_tokens:
1417
+ break
1418
+
1419
+ all_generated = torch.cat(generated, dim=1)
1420
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1421
+
1422
+ return output_ids, nfe
1423
+
1424
+
1425
+ @torch.no_grad()
1426
+ def linear_spec_generate_mp(
1427
+ self,
1428
+ prompt_ids: torch.Tensor,
1429
+ max_new_tokens: int = 512,
1430
+ block_length: int = 32,
1431
+ temperature: float = 0.0,
1432
+ mask_token_id: Optional[int] = None,
1433
+ eos_token_id: Optional[int] = None,
1434
+ max_paths: int = 16,
1435
+ uncertain_threshold: float = 0.7,
1436
+ top_k_candidates: int = 2,
1437
+ threshold: float = 0.0,
1438
+ max_thinking_tokens: Optional[int] = None,
1439
+ end_think_token_id: Optional[int] = None,
1440
+ ):
1441
+ """Linear speculative decoding with multi-path tree verification.
1442
+
1443
+ Self-contained method — no external file dependencies beyond the model itself.
1444
+
1445
+ Each iteration costs 2 NFE (1 draft + 1 verify):
1446
+ 1. Draft: single-step bidirectional diffusion fills a block of masks.
1447
+ 2. Verify: tree-structured AR verification with multiple candidate paths.
1448
+
1449
+ Multi-path verification identifies low-confidence draft positions and
1450
+ explores top-k alternative tokens. All candidate paths share a trie
1451
+ prefix and are verified in one forward pass via a 4D tree-ancestry
1452
+ attention mask (~40 tokens), picking the path with the longest
1453
+ accepted prefix.
1454
+
1455
+ Benchmark results (NeMo Skills prompt, enable_thinking=False):
1456
+ GSM8K bl=32: +17.1% UW-TPF vs vanilla (acc 93.9%)
1457
+ MBPP bl=64: +17.8% UW-TPF vs vanilla (pass@1 78.2%)
1458
+
1459
+ Args:
1460
+ prompt_ids: (1, prompt_len) input token IDs.
1461
+ max_new_tokens: Maximum tokens to generate.
1462
+ block_length: Draft block size. Use 32 for math, 64 for code.
1463
+ temperature: Sampling temperature (0.0 = greedy).
1464
+ eos_token_id: Stop token ID.
1465
+ max_paths: Tree verification budget. 16 = up to 4 uncertain
1466
+ positions x 2 candidates each.
1467
+ uncertain_threshold: Confidence below which a position is
1468
+ considered uncertain and expanded with alternatives.
1469
+ top_k_candidates: Number of alternative tokens to try at each
1470
+ uncertain position.
1471
+
1472
+ Returns:
1473
+ output_ids: (1, prompt_len + generated_len) full sequence.
1474
+ nfe: Total number of forward evaluations.
1475
+ """
1476
+ from itertools import product as _product
1477
+
1478
+ if prompt_ids.shape[0] != 1:
1479
+ raise ValueError("Requires batch_size == 1")
1480
+
1481
+ device = prompt_ids.device
1482
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1483
+ if eos_token_id is None:
1484
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1485
+
1486
+ def _set_dlm(val: bool):
1487
+ for layer in self.encoder.layers:
1488
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1489
+ layer.self_attn.diffusion_lm = val
1490
+
1491
+ def _crop_cache(kv, length):
1492
+ # transformers 4.55 exposes .key_cache/.value_cache lists; 5.0 moved them under .layers[i].keys/.values.
1493
+ for li in range(len(kv)):
1494
+ if hasattr(kv, 'layers'):
1495
+ layer = kv.layers[li]
1496
+ layer.keys = layer.keys[:, :, :length]
1497
+ layer.values = layer.values[:, :, :length]
1498
+ else:
1499
+ kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1500
+ kv.value_cache[li] = kv.value_cache[li][:, :, :length]
1501
+ kv._seen_tokens = length
1502
+
1503
+ # ----- tree verify helpers (inlined) -----
1504
+
1505
+ def _mp_verify(block, draft_probs, draft_conf, past_kv, cache_len):
1506
+ """Multi-path verify via batch-stacking (flash-attention compatible).
1507
+
1508
+ Unlike tree attention (4D mask), batch-stacking expands the KV cache
1509
+ batch dimension and runs all candidate paths as separate batch entries.
1510
+ This keeps flash attention + GQA enabled, avoiding OOM from the 4D
1511
+ mask path which disables both.
1512
+
1513
+ Returns (accepted_toks, n_accepted, past_kv, next_tok) or None.
1514
+ """
1515
+ bl = block.shape[1]
1516
+
1517
+ # Identify uncertain positions
1518
+ is_filled = block[0] != token_mask_id
1519
+ pos_conf = torch.zeros(bl, device=device)
1520
+ pos_conf[0] = float('inf')
1521
+ for p in range(1, bl):
1522
+ if is_filled[p]:
1523
+ c = draft_conf[0, p].item()
1524
+ pos_conf[p] = c if c != float('-inf') else float('inf')
1525
+ else:
1526
+ pos_conf[p] = float('-inf')
1527
+
1528
+ unc_mask = (pos_conf < uncertain_threshold) & (pos_conf > float('-inf'))
1529
+ unc_pos = unc_mask.nonzero(as_tuple=True)[0].tolist()
1530
+ if not unc_pos:
1531
+ return None
1532
+
1533
+ import math as _math
1534
+ max_unc = min(len(unc_pos), max(1, int(_math.log2(max_paths))))
1535
+ unc_pos = sorted(unc_pos)[:max_unc]
1536
+
1537
+ # Build candidate blocks
1538
+ topk_at = {}
1539
+ for p in unc_pos:
1540
+ _, ids = draft_probs[0, p].topk(top_k_candidates)
1541
+ topk_at[p] = ids.tolist()
1542
+
1543
+ combos = list(_product(*(topk_at[p] for p in sorted(topk_at))))[:max_paths]
1544
+ num_paths = len(combos)
1545
+ if num_paths <= 1:
1546
+ return None
1547
+
1548
+ candidate_blocks = block.expand(num_paths, -1).clone()
1549
+ pos_list = sorted(topk_at.keys())
1550
+ for pi, combo in enumerate(combos):
1551
+ for ci, p in enumerate(pos_list):
1552
+ candidate_blocks[pi, p] = combo[ci]
1553
+
1554
+ # Expand KV cache batch dimension (shared, no copy)
1555
+ for li in range(len(past_kv)):
1556
+ if hasattr(past_kv, 'layers'):
1557
+ layer = past_kv.layers[li]
1558
+ layer.keys = layer.keys.expand(num_paths, -1, -1, -1)
1559
+ layer.values = layer.values.expand(num_paths, -1, -1, -1)
1560
+ else:
1561
+ past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1562
+ past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
1563
+
1564
+ # Batched causal verify — uses flash attention + GQA
1565
+ _set_dlm(False)
1566
+ enc_out = self.encoder(
1567
+ input_ids=candidate_blocks,
1568
+ past_key_values=past_kv,
1569
+ use_cache=True,
1570
+ use_causal_mask=True,
1571
+ )
1572
+ past_kv = enc_out.past_key_values
1573
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1574
+
1575
+ if temperature > 0:
1576
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1577
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(num_paths, bl)
1578
+ else:
1579
+ ar_tokens = vlogits.argmax(dim=-1)
1580
+
1581
+ # Find best path (longest accepted prefix)
1582
+ best_acc, best_pidx = 0, 0
1583
+ for pi in range(num_paths):
1584
+ acc = 0
1585
+ for i in range(bl - 1):
1586
+ if ar_tokens[pi, i].item() == candidate_blocks[pi, i + 1].item():
1587
+ acc += 1
1588
+ else:
1589
+ break
1590
+ acc += 1
1591
+ if acc > best_acc:
1592
+ best_acc, best_pidx = acc, pi
1593
+
1594
+ accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1595
+
1596
+ # Extract winning path's KV cache slice
1597
+ for li in range(len(past_kv)):
1598
+ if hasattr(past_kv, 'layers'):
1599
+ layer = past_kv.layers[li]
1600
+ layer.keys = layer.keys[best_pidx:best_pidx+1].contiguous()
1601
+ layer.values = layer.values[best_pidx:best_pidx+1].contiguous()
1602
+ else:
1603
+ past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1604
+ past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
1605
+ _crop_cache(past_kv, cache_len + best_acc)
1606
+
1607
+ return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
1608
+
1609
+ # ── Prefill (causal) ──
1610
+ _set_dlm(False)
1611
+ enc_out = self.encoder(
1612
+ input_ids=prompt_ids, past_key_values=DynamicCache(),
1613
+ use_cache=True, use_causal_mask=True,
1614
+ )
1615
+ past_key_values = enc_out.past_key_values
1616
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1617
+ nfe = 1
1618
+
1619
+ if temperature > 0:
1620
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), 1)
1621
+ else:
1622
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1623
+
1624
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1625
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1626
+
1627
+ generated = [next_token]
1628
+ total_gen = 1
1629
+
1630
+ # ── Main draft-verify loop ──
1631
+ while total_gen < max_new_tokens:
1632
+ cache_len = past_key_values.get_seq_length()
1633
+
1634
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1635
+ block[0, 0] = next_token.item()
1636
+
1637
+ # Draft: single-step bidirectional diffusion (1 NFE)
1638
+ _set_dlm(True)
1639
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1640
+ nfe += 1
1641
+
1642
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1643
+ if temperature > 0:
1644
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1645
+ draft_tokens = torch.multinomial(
1646
+ draft_probs.view(-1, draft_probs.shape[-1]), 1
1647
+ ).view(1, block_length)
1648
+ else:
1649
+ draft_tokens = draft_logits.argmax(dim=-1)
1650
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1651
+
1652
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1653
+ is_mask = block == token_mask_id
1654
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1655
+ block[is_mask] = draft_tokens[is_mask]
1656
+
1657
+ # Verify: multi-path batch-stacking (1 NFE, flash-attention compatible)
1658
+ result = _mp_verify(block, draft_probs, draft_conf, past_key_values, cache_len)
1659
+
1660
+ if result is not None:
1661
+ accepted_toks, accepted, past_key_values, next_token = result
1662
+ nfe += 1
1663
+ else:
1664
+ # No uncertain positions — single-path causal verify
1665
+ _set_dlm(False)
1666
+ enc_out = self.encoder(
1667
+ input_ids=block, past_key_values=past_key_values,
1668
+ use_cache=True, use_causal_mask=True,
1669
+ )
1670
+ past_key_values = enc_out.past_key_values
1671
+ nfe += 1
1672
+
1673
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1674
+ if temperature > 0:
1675
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1676
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(1, block_length)
1677
+ else:
1678
+ ar_tokens = vlogits.argmax(dim=-1)
1679
+
1680
+ accepted = 0
1681
+ for i in range(block_length - 1):
1682
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1683
+ accepted += 1
1684
+ else:
1685
+ break
1686
+ accepted += 1
1687
+
1688
+ accepted_toks = ar_tokens[:, :accepted]
1689
+ _crop_cache(past_key_values, cache_len + accepted)
1690
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1691
+
1692
+ generated.append(accepted_toks)
1693
+ total_gen += accepted
1694
+
1695
+ if eos_token_id is not None:
1696
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1697
+ if len(eos_pos) > 0:
1698
+ first_eos = eos_pos[0].item()
1699
+ generated[-1] = accepted_toks[:, :first_eos + 1]
1700
+ total_gen = total_gen - accepted + first_eos + 1
1701
+ break
1702
+
1703
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1704
+ if total_gen > max_thinking_tokens:
1705
+ all_gen = torch.cat(generated, dim=1)
1706
+ if not (all_gen == end_think_token_id).any():
1707
+ next_token = torch.tensor(
1708
+ [[end_think_token_id]], device=device
1709
+ )
1710
+
1711
+ if total_gen >= max_new_tokens:
1712
+ break
1713
+
1714
+ all_generated = torch.cat(generated, dim=1)
1715
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1716
+ return output_ids, nfe
1717
+
1718
+
1719
+ @torch.no_grad()
1720
+ def linear_spec_generate_lora(
1721
+ self,
1722
+ prompt_ids: torch.Tensor,
1723
+ max_new_tokens: int = 128,
1724
+ block_length: int = 32,
1725
+ temperature: float = 0.0,
1726
+ mask_token_id: Optional[int] = None,
1727
+ eos_token_id: Optional[int] = None,
1728
+ threshold: float = 0.0,
1729
+ rebuild_kv: str = 'none',
1730
+ max_thinking_tokens: Optional[int] = None,
1731
+ end_think_token_id: Optional[int] = None,
1732
+ ):
1733
+ """Linear speculative decoding: diffusion draft + AR verify.
1734
+ LoRA adapter toggling: ON for draft (bidirectional), OFF for verify (causal).
1735
+ Returns (output_ids, nfe).
1736
+ """
1737
+ if prompt_ids.shape[0] != 1:
1738
+ raise ValueError("linear_spec_generate requires batch_size == 1")
1739
+
1740
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1741
+ if eos_token_id is None:
1742
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1743
+
1744
+ device = prompt_ids.device
1745
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1746
+
1747
+ def _set_diffusion_lm(val: bool):
1748
+ for layer in self.encoder.layers:
1749
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1750
+ layer.self_attn.diffusion_lm = val
1751
+
1752
+ def _toggle_adapters(model, enable: bool):
1753
+ for module in model.modules():
1754
+ if hasattr(module, '_disable_adapters'):
1755
+ module._disable_adapters = not enable
1756
+
1757
+ # Prefill (causal, LoRA OFF)
1758
+ _set_diffusion_lm(False)
1759
+ _toggle_adapters(self, False)
1760
+ enc_out = self.encoder(
1761
+ input_ids=prompt_ids,
1762
+ past_key_values=DynamicCache(),
1763
+ use_cache=True,
1764
+ use_causal_mask=True,
1765
+ )
1766
+ past_key_values = enc_out.past_key_values
1767
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1768
+ nfe = 1
1769
+
1770
+ if temperature > 0:
1771
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
1772
+ else:
1773
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1774
+
1775
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1776
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1777
+
1778
+ generated = [next_token]
1779
+ total_gen = 1
1780
+
1781
+ while total_gen < max_new_tokens:
1782
+ cache_len = past_key_values.get_seq_length()
1783
+
1784
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1785
+ block[0, 0] = next_token.item()
1786
+
1787
+ # Draft (bidirectional, LoRA ON)
1788
+ _set_diffusion_lm(True)
1789
+ _toggle_adapters(self, True)
1790
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1791
+ nfe += 1
1792
+
1793
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1794
+ if dream_style:
1795
+ draft_logits = torch.cat([draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1)
1796
+
1797
+ if temperature > 0:
1798
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1799
+ draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1).view(1, block_length)
1800
+ else:
1801
+ draft_tokens = draft_logits.argmax(dim=-1)
1802
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1803
+
1804
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1805
+ is_mask = block == token_mask_id
1806
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1807
+ unmask = draft_conf > threshold
1808
+ if unmask.sum() > 0:
1809
+ block[unmask] = draft_tokens[unmask]
1810
+
1811
+ # Verify (causal, LoRA OFF)
1812
+ _set_diffusion_lm(False)
1813
+ _toggle_adapters(self, False)
1814
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=True, use_causal_mask=True)
1815
+ past_key_values = enc_out.past_key_values
1816
+ nfe += 1
1817
+
1818
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1819
+ if temperature > 0:
1820
+ ar_tokens = torch.multinomial(torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]), num_samples=1).view(1, block_length)
1821
+ else:
1822
+ ar_tokens = verify_logits.argmax(dim=-1)
1823
+
1824
+ accepted = 0
1825
+ for i in range(block_length - 1):
1826
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1827
+ accepted += 1
1828
+ else:
1829
+ break
1830
+ accepted += 1 # bonus token
1831
+
1832
+ accepted_toks = ar_tokens[:, :accepted]
1833
+ generated.append(accepted_toks)
1834
+ total_gen += accepted
1835
+
1836
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1837
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1838
+
1839
+ # EOS check
1840
+ if eos_token_id is not None:
1841
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1842
+ if len(eos_pos) > 0:
1843
+ first_eos = eos_pos[0].item()
1844
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1845
+ total_gen = total_gen - accepted + first_eos + 1
1846
+ break
1847
+
1848
+ # Thinking budget enforcement
1849
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1850
+ if total_gen > max_thinking_tokens:
1851
+ all_gen = torch.cat(generated, dim=1)
1852
+ if not (all_gen == end_think_token_id).any():
1853
+ next_token = torch.tensor([[end_think_token_id]], device=device)
1854
+
1855
+ if total_gen >= max_new_tokens:
1856
+ break
1857
+
1858
+ all_generated = torch.cat(generated, dim=1)
1859
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1860
+ return output_ids, nfe
modeling_nemotron_labs_diffusion.py DELETED
@@ -1,870 +0,0 @@
1
- import copy
2
- from dataclasses import dataclass
3
- from typing import Optional, Tuple
4
- import numpy as np
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from torch import nn
9
- from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
10
- from transformers.utils import ModelOutput
11
-
12
- from torch.nn.attention.flex_attention import flex_attention, create_block_mask
13
-
14
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
15
-
16
- from transformers.processing_utils import Unpack
17
-
18
- from transformers.cache_utils import Cache, DynamicCache
19
-
20
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
21
-
22
- from transformers.generation import GenerationMixin
23
-
24
- import math
25
-
26
- from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
27
- from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
28
-
29
- __all__ = ["NemotronLabsDiffusionModel", "NemotronLabsDiffusionFlexAttention"]
30
-
31
- @dataclass
32
- class NemotronLabsDiffusionOutputWithPast(ModelOutput):
33
- loss: torch.FloatTensor | None = None
34
- logits: torch.FloatTensor | None = None
35
- causal_logits: torch.FloatTensor | None = None
36
- past_key_values: Cache | None = None
37
- hidden_states: tuple[torch.FloatTensor, ...] | None = None
38
- attentions: tuple[torch.FloatTensor, ...] | None = None
39
-
40
-
41
- @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
42
- def fused_flex_attention(q, k, v, block_mask=None):
43
- return flex_attention(q, k, v, block_mask=block_mask)
44
-
45
-
46
- class NemotronLabsDiffusionFlexAttention(Ministral3Attention):
47
- def __init__(self, *args, **kwargs):
48
- super().__init__(*args, **kwargs)
49
-
50
- self.block_size = self.config.block_size
51
- self.block_diff_mask = None
52
-
53
- import torch._dynamo.config as dcfg
54
- dcfg.cache_size_limit = 512
55
-
56
- def compute_block_mask(self, mode, q_len, block_size=None):
57
-
58
- def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
59
- x0_flag_q = (q_idx >= n)
60
- x0_flag_kv = (kv_idx >= n)
61
-
62
- # Compute block indices
63
- block_q = torch.where(x0_flag_q == 1,
64
- (q_idx - n) // block_size,
65
- q_idx // block_size)
66
- block_kv = torch.where(x0_flag_kv == 1,
67
- (kv_idx - n) // block_size,
68
- kv_idx // block_size)
69
-
70
- # **1. Block Diagonal Mask (M_BD) **
71
- block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
72
-
73
- # **2. Offset Block-Causal Mask (M_OBC) **
74
- offset_block_causal = (
75
- (block_q > block_kv)
76
- & (x0_flag_kv == 1)
77
- & (x0_flag_q == 0)
78
- )
79
-
80
- # **3. Fully Causal Mask (M_BC) **
81
- fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
82
-
83
- # **4. Combine Masks **
84
- return block_diagonal | offset_block_causal | fully_causal
85
-
86
- attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
87
-
88
- block_mask = create_block_mask(
89
- attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
90
- )
91
-
92
- return block_mask
93
-
94
-
95
- def forward(
96
- self,
97
- hidden_states: torch.Tensor,
98
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
99
- attention_mask: Optional[torch.Tensor],
100
- past_key_values: Optional[Cache] = None,
101
- cache_position: Optional[torch.LongTensor] = None,
102
- is_training: bool = True,
103
- **kwargs: Unpack[FlashAttentionKwargs],
104
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
105
- bsz, q_len, _ = hidden_states.size()
106
- input_shape = hidden_states.shape[:-1]
107
- hidden_shape = (*input_shape, -1, self.head_dim)
108
-
109
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
110
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
111
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
112
-
113
- cos, sin = position_embeddings
114
-
115
- if is_training:
116
- # Split query and key states in half along sequence length dimension
117
- q1, q2 = query_states.chunk(2, dim=2)
118
- k1, k2 = key_states.chunk(2, dim=2)
119
-
120
- # Apply RoPE independently to each half
121
- q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
122
- q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
123
-
124
- # Recombine the halves
125
- query_states = torch.cat([q1, q2], dim=2)
126
- key_states = torch.cat([k1, k2], dim=2)
127
- else:
128
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
129
-
130
- query_states = query_states * _get_llama_4_attn_scale(
131
- cache_position,
132
- self.config.rope_parameters.get("llama_4_scaling_beta"),
133
- self.config.rope_parameters.get("original_max_position_embeddings"),
134
- ).to(query_states.dtype)
135
-
136
- if past_key_values is not None:
137
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
138
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
139
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
140
-
141
- key_states = repeat_kv(key_states, self.num_key_value_groups)
142
- value_states = repeat_kv(value_states, self.num_key_value_groups)
143
-
144
- if self.block_diff_mask is None or q_len != self.block_diff_mask.shape[-2]:
145
- block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
146
- else:
147
- block_mask = self.block_diff_mask
148
-
149
- attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
150
- attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
151
-
152
- attn_output = self.o_proj(attn_output)
153
-
154
- return attn_output, None
155
-
156
-
157
- class NemotronLabsDiffusionModel(Ministral3PreTrainedModel, GenerationMixin):
158
- """
159
- A single model with:
160
- - a bidirectional encoder + diffusion‐LM head over A
161
- - a causal decoder + LM head over B, conditioned on F_A
162
- """
163
-
164
- def __init__(self, config: NemotronLabsDiffusionConfig):
165
- super().__init__(config)
166
-
167
- self.mask_token_id = config.mask_token_id
168
-
169
- diffusion_config = copy.deepcopy(config)
170
- diffusion_config.diffusion_lm = True
171
-
172
- if config.dlm_paradigm == 'block_diff':
173
- diffusion_config.attn_class = NemotronLabsDiffusionFlexAttention
174
- elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
175
- diffusion_config.attn_class = Ministral3Attention
176
- if config.dlm_paradigm == 'autoregressive':
177
- diffusion_config.diffusion_lm = False
178
- else:
179
- raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
180
-
181
- self.encoder = Ministral3Model(diffusion_config)
182
- self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
183
- self.vocab_size = config.vocab_size
184
-
185
- self.post_init()
186
-
187
-
188
- def get_input_embeddings(self):
189
- return self.encoder.embed_tokens
190
-
191
- def set_input_embeddings(self, value):
192
- self.encoder.embed_tokens = value
193
-
194
- def get_output_embeddings(self):
195
- return self.diffusion_head
196
-
197
- def set_output_embeddings(self, new_embeddings):
198
- self.diffusion_head = new_embeddings
199
-
200
-
201
- def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
202
- b, l = input_ids.shape
203
- device = input_ids.device
204
-
205
- if self.config.dp_varying_mask_ratio:
206
- # Enable different random seeds for each DP rank during sampling
207
- import torch.distributed as dist
208
- dp_rank = 0
209
- if dist.is_initialized():
210
- try:
211
- dp_rank = dist.get_rank()
212
- except Exception:
213
- dp_rank = 0
214
- # Use a local generator to avoid affecting global RNG state
215
- generator = torch.Generator(device=device)
216
- generator.manual_seed(torch.seed() + dp_rank)
217
- else:
218
- generator = None
219
-
220
- t = torch.rand(b, device=device, generator=generator)
221
-
222
- p_mask = (1 - eps) * t + eps # shape: (b,)
223
- p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
224
-
225
- masked_indices = torch.rand((b, l), device=device) < p_mask
226
-
227
- if loss_mask is not None:
228
- masked_indices[loss_mask == 0] = 0
229
-
230
- noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
231
-
232
- return noisy_batch, masked_indices, p_mask
233
-
234
-
235
- def forward(
236
- self,
237
- input_ids: torch.LongTensor,
238
- attention_mask: Optional[torch.Tensor] = None,
239
- position_ids: Optional[torch.LongTensor] = None,
240
- labels: Optional[torch.LongTensor] = None,
241
- split_len: Optional[int] = None,
242
- past_key_values: Optional[Cache] = None,
243
- block_size: Optional[int] = None,
244
- eps: float = 1e-3,
245
- is_teacher: bool = False,
246
- masked_indices: Optional[torch.Tensor] = None,
247
- p_mask: Optional[torch.Tensor] = None,
248
- teacher_logits: Optional[torch.Tensor] = None,
249
- masked_indices_teacher: Optional[torch.Tensor] = None,
250
- loss_mask: Optional[torch.Tensor] = None,
251
- ce_loss_weight: float = 1.0,
252
- output_last_hidden_states_only: bool = False,
253
- skip_loss: bool = False,
254
- **kwargs,
255
- ) -> CausalLMOutputWithPast:
256
-
257
- batch_size, seq_len = input_ids.shape
258
-
259
- if self.config.dlm_paradigm == 'block_diff':
260
- if labels is not None and block_size is None:
261
- block_size = self.config.block_size
262
- elif self.config.dlm_paradigm not in ('bidirectional', 'autoregressive'):
263
- raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
264
-
265
- if labels is not None and self.config.dlm_paradigm != 'autoregressive':
266
- if masked_indices is not None:
267
- # assert p_mask is not None
268
-
269
- if loss_mask is not None:
270
- masked_indices[loss_mask == 0] = 0
271
-
272
- noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
273
-
274
- else:
275
- noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
276
-
277
- else:
278
- noisy_inputs = input_ids
279
- masked_indices = None
280
- p_mask = None
281
-
282
- input_ids_len = noisy_inputs.shape[1]
283
- if labels is not None and self.config.dlm_paradigm == 'block_diff':
284
- if position_ids is None:
285
- position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
286
- noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
287
-
288
- enc_out = self.encoder(
289
- past_key_values=past_key_values,
290
- input_ids=noisy_inputs,
291
- attention_mask=attention_mask,
292
- position_ids=position_ids,
293
- is_training=(labels is not None),
294
- **kwargs,
295
- )
296
-
297
- if output_last_hidden_states_only:
298
- return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
299
-
300
- logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
301
- causal_logits = None
302
-
303
- if labels is not None and self.config.dlm_paradigm == 'block_diff':
304
- causal_logits = logits[:, input_ids_len:]
305
- logits = logits[:, :input_ids_len]
306
-
307
- loss = None
308
- if labels is not None and not skip_loss:
309
- if self.config.dlm_paradigm == 'autoregressive':
310
- shift_logits = logits[..., :-1, :].contiguous()
311
- shift_labels = labels[..., 1:].contiguous()
312
-
313
- if loss_mask is None:
314
- loss_fct = CrossEntropyLoss()
315
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
316
- shift_labels = shift_labels.view(-1)
317
- loss = loss_fct(shift_logits, shift_labels)
318
-
319
- else:
320
- loss_mask = loss_mask[..., 1:].contiguous()
321
-
322
- loss_fct = CrossEntropyLoss(reduction='none')
323
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
324
- shift_labels = shift_labels.view(-1)
325
- shift_labels = shift_labels.to(shift_logits.device)
326
-
327
- token_losses = loss_fct(shift_logits, shift_labels)
328
-
329
- flat_loss_mask = loss_mask.reshape(-1)
330
- loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
331
-
332
- else:
333
- # LLaDA-style diffusion loss on masked positions.
334
- # Token-wise cross entropy loss on masked positions.
335
- token_loss = torch.nn.functional.cross_entropy(
336
- logits[masked_indices],
337
- labels[masked_indices],
338
- reduction='none'
339
- ) / p_mask[masked_indices]
340
-
341
- num_mask_tokens = masked_indices.sum()
342
-
343
- # global_loss_avg=True: loss is reduced externally by global token count.
344
- loss = token_loss.sum()
345
-
346
- if self.config.dlm_loss_weight is not None:
347
- loss = self.config.dlm_loss_weight * loss
348
-
349
- if self.config.dlm_paradigm == 'block_diff':
350
- # AR-side loss for block-diffusion paradigm.
351
- causal_logits = causal_logits[..., :-1, :].contiguous()
352
- causal_logits = causal_logits.view(-1, causal_logits.size(-1))
353
- causal_labels = labels[..., 1:].contiguous().view(-1)
354
-
355
- loss_fct = CrossEntropyLoss(reduction='sum')
356
- ar_loss = loss_fct(causal_logits, causal_labels)
357
-
358
- self.loss_diffusion = loss.detach().item() / num_mask_tokens
359
- self.loss_ar = ar_loss.detach().item() / seq_len
360
-
361
- loss = loss + self.config.ar_loss_weight * ar_loss
362
-
363
- # global_loss_avg=True: return (sum_loss, token_count) for external mean.
364
- if self.config.dlm_paradigm == 'block_diff':
365
- loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
366
- else:
367
- loss = (loss, num_mask_tokens)
368
-
369
- return NemotronLabsDiffusionOutputWithPast(
370
- loss=loss if not is_teacher else logits,
371
- logits=logits,
372
- causal_logits=causal_logits,
373
- past_key_values=enc_out.past_key_values,
374
- hidden_states=None,
375
- attentions=None,
376
- )
377
-
378
-
379
- @torch.no_grad()
380
- def generate(
381
- self,
382
- prompt_ids: torch.Tensor,
383
- max_new_tokens: int,
384
- block_length: int,
385
- threshold: Optional[float] = None,
386
- causal_context: bool = True,
387
- temperature: float = 0.0,
388
- eos_token_id: Optional[int] = None,
389
- max_thinking_tokens: Optional[int] = None,
390
- end_think_token_id: Optional[int] = None,
391
- ):
392
- """Block-wise diffusion decoding with prefix-cached KV (LLaDA-style).
393
-
394
- Each block: append `block_length` mask tokens, then iteratively unmask
395
- by confidence top-k (with optional threshold). When `causal_context`,
396
- the KV cache and the next-block seed are produced via a causal forward
397
- between blocks (flipping `self_attn.diffusion_lm`), matching the AR
398
- objective at block boundaries.
399
-
400
- Returns (output_ids, nfe) — output_ids includes the prompt.
401
- """
402
- if eos_token_id is None:
403
- eos_token_id = getattr(self.config, "eos_token_id", None)
404
- mask_id = self.mask_token_id
405
-
406
- x_accum = prompt_ids.clone()
407
- B = prompt_ids.shape[0]
408
-
409
- assert max_new_tokens % block_length == 0
410
- num_blocks = max_new_tokens // block_length
411
- # one denoising step per generated token (matches legacy chat_utils call)
412
- steps_per_block = block_length
413
-
414
- nfe = 0
415
-
416
- def _set_diffusion_lm(val: bool):
417
- for layer in self.encoder.layers:
418
- if hasattr(layer.self_attn, "diffusion_lm"):
419
- layer.self_attn.diffusion_lm = val
420
-
421
- # Initial causal prefill produces the KV cache and the next-block seed.
422
- if causal_context:
423
- _set_diffusion_lm(False)
424
- output = self(prompt_ids, use_cache=True, use_causal_mask=causal_context)
425
- past_key_values = output.past_key_values
426
- if causal_context:
427
- _set_diffusion_lm(True)
428
-
429
- next_token = None
430
- if causal_context:
431
- last_logit = output.logits[:, -1, :]
432
- if temperature > 0:
433
- next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
434
- else:
435
- next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
436
-
437
- for num_block in range(num_blocks):
438
- mask_block = torch.full(
439
- (B, block_length), mask_id, dtype=prompt_ids.dtype, device=prompt_ids.device,
440
- )
441
- if causal_context:
442
- mask_block[:, 0] = next_token[:, 0]
443
-
444
- x_accum = torch.cat([x_accum, mask_block], dim=1)
445
- block_start = prompt_ids.size(1) + num_block * block_length
446
- block_slice = slice(block_start, block_start + block_length)
447
-
448
- # Thinking-budget enforcement: if we've passed max_thinking_tokens
449
- # without an end-think marker, inject one into this block.
450
- if end_think_token_id is not None and max_thinking_tokens is not None:
451
- tokens_before = num_block * block_length
452
- tokens_after = tokens_before + block_length
453
- if tokens_after > max_thinking_tokens:
454
- gen_so_far = x_accum[:, prompt_ids.size(1):block_start]
455
- has_end_think = (
456
- (gen_so_far == end_think_token_id).any(dim=1)
457
- if gen_so_far.size(1) > 0
458
- else torch.zeros(B, dtype=torch.bool, device=prompt_ids.device)
459
- )
460
- if not has_end_think.all():
461
- offset = max(0, max_thinking_tokens - tokens_before)
462
- inject_pos = block_start + offset
463
- for b in range(B):
464
- if not has_end_think[b]:
465
- x_accum[b, inject_pos] = end_think_token_id
466
-
467
- mask_block_idx0 = x_accum[:, block_slice] == mask_id
468
- num_transfer_tokens = _get_num_transfer_tokens(mask_block_idx0, steps_per_block)
469
-
470
- # Denoise the current block by repeated confidence-based unmasking.
471
- for i in range(steps_per_block):
472
- mask_block_idx = x_accum[:, block_slice] == mask_id
473
- if mask_block_idx.sum() == 0:
474
- break
475
-
476
- nfe += 1
477
- logits_block = self(
478
- x_accum[:, block_slice],
479
- past_key_values=past_key_values,
480
- use_cache=False,
481
- ).logits
482
-
483
- x0, transfer_idx = _get_transfer_index(
484
- logits_block, temperature, mask_block_idx, x_accum[:, block_slice],
485
- num_transfer_tokens=num_transfer_tokens[:, i], threshold=threshold,
486
- )
487
- cur = x_accum[:, block_slice].clone()
488
- cur[transfer_idx] = x0[transfer_idx]
489
- x_accum[:, block_slice] = cur
490
-
491
- if eos_token_id is not None:
492
- block_tokens = x_accum[:, block_slice]
493
- eos_mask = block_tokens == eos_token_id
494
- if eos_mask.any(dim=1).any():
495
- after_eos = eos_mask.cumsum(dim=1).bool()
496
- mask_before = (block_tokens == mask_id) & ~after_eos
497
- if (eos_mask.any(dim=1) & ~mask_before.any(dim=1)).any():
498
- break
499
-
500
- # Post-block: causal forward over the block to update the KV cache
501
- # and (when causal_context) sample the seed for the next block.
502
- if causal_context:
503
- _set_diffusion_lm(False)
504
- output = self(
505
- x_accum[:, block_slice],
506
- past_key_values=past_key_values,
507
- use_cache=True,
508
- use_causal_mask=causal_context,
509
- )
510
- past_key_values = output.past_key_values
511
- nfe += 1
512
-
513
- if causal_context:
514
- _set_diffusion_lm(True)
515
- last_logit = output.logits[:, -1, :]
516
- if temperature > 0:
517
- next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
518
- else:
519
- next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
520
-
521
- if eos_token_id is not None:
522
- gen_so_far = x_accum[:, prompt_ids.size(1):]
523
- is_eos = gen_so_far == eos_token_id
524
- if is_eos.any(dim=1).all():
525
- first_eos = is_eos.to(torch.int64).argmax(dim=1)
526
- max_eos = first_eos.max().item()
527
- return x_accum[:, : prompt_ids.size(1) + max_eos + 1], nfe
528
-
529
- return x_accum, nfe
530
-
531
-
532
-
533
- @torch.no_grad()
534
- def ar_generate(
535
- self,
536
- prompt_ids: torch.Tensor,
537
- max_new_tokens: int = 128,
538
- temperature: float = 0.0,
539
- eos_token_id: Optional[int] = None,
540
- max_thinking_tokens: Optional[int] = None,
541
- end_think_token_id: Optional[int] = None,
542
- ) -> tuple:
543
- """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
544
-
545
- Bypasses NemotronLabsDiffusionModel.forward() to avoid diffusion-specific
546
- code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
547
- position_ids, and use_cache so the KV cache and causal masking behave
548
- identically to MistralForCausalLM / vLLM.
549
-
550
- Returns:
551
- (output_ids, nfe) where output_ids includes the prompt.
552
- """
553
- for layer in self.encoder.layers:
554
- if hasattr(layer.self_attn, 'diffusion_lm'):
555
- layer.self_attn.diffusion_lm = False
556
-
557
- if eos_token_id is None:
558
- eos_token_id = getattr(self.config, 'eos_token_id', None)
559
-
560
- device = prompt_ids.device
561
- batch_size, prompt_len = prompt_ids.shape
562
-
563
- past_key_values = DynamicCache()
564
- cache_position = torch.arange(prompt_len, device=device)
565
- position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
566
-
567
- enc_out = self.encoder(
568
- input_ids=prompt_ids,
569
- position_ids=position_ids,
570
- past_key_values=past_key_values,
571
- use_cache=True,
572
- cache_position=cache_position,
573
- )
574
- past_key_values = enc_out.past_key_values
575
- next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
576
-
577
- generated_tokens = []
578
- nfe = 0
579
-
580
- for step in range(max_new_tokens):
581
- nfe += 1
582
-
583
- if temperature > 0:
584
- probs = torch.softmax(next_logit / temperature, dim=-1)
585
- next_token = torch.multinomial(probs, num_samples=1)
586
- else:
587
- next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
588
-
589
- # ---- thinking budget enforcement ----
590
- if end_think_token_id is not None and max_thinking_tokens is not None:
591
- if step >= max_thinking_tokens:
592
- if generated_tokens:
593
- gen_tensor = torch.cat(generated_tokens, dim=1)
594
- has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
595
- else:
596
- has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
597
- for b in range(batch_size):
598
- if not has_end_think[b]:
599
- next_token[b] = end_think_token_id
600
-
601
- generated_tokens.append(next_token)
602
-
603
- if eos_token_id is not None and (next_token == eos_token_id).all():
604
- break
605
-
606
- if step < max_new_tokens - 1:
607
- cur_pos = prompt_len + step
608
- step_cache_pos = torch.tensor([cur_pos], device=device)
609
- step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
610
-
611
- enc_out = self.encoder(
612
- input_ids=next_token,
613
- position_ids=step_pos_ids,
614
- past_key_values=past_key_values,
615
- use_cache=True,
616
- cache_position=step_cache_pos,
617
- )
618
- past_key_values = enc_out.past_key_values
619
- next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
620
-
621
- all_generated = torch.cat(generated_tokens, dim=1)
622
- output_ids = torch.cat([prompt_ids, all_generated], dim=1)
623
- return output_ids, nfe
624
-
625
-
626
- @torch.no_grad()
627
- def linear_spec_generate(
628
- self,
629
- prompt_ids: torch.Tensor,
630
- max_new_tokens: int = 128,
631
- block_length: int = 32,
632
- temperature: float = 0.0,
633
- mask_token_id: Optional[int] = None,
634
- eos_token_id: Optional[int] = None,
635
- max_thinking_tokens: Optional[int] = None,
636
- end_think_token_id: Optional[int] = None,
637
- threshold: float = 0.0,
638
- ):
639
- """Linear speculative decoding: diffusion draft + AR verify.
640
-
641
- Each iteration: (1) draft the next block under bidirectional attention,
642
- (2) verify the drafted block under causal attention, accept the longest
643
- prefix where draft matches AR + one bonus token, advance the KV cache.
644
-
645
- LoRA-aware: when a PEFT adapter is attached to the model (e.g.
646
- ``linear_spec_lora``), it is toggled ON for the bidirectional draft
647
- phase and OFF for the causal prefill / verify phases — so the adapter
648
- only specializes the diffusion-mode forward and AR semantics are
649
- preserved. With no adapter loaded the calls are no-ops.
650
-
651
- Returns ``(output_ids, nfe)`` — ``output_ids`` includes the prompt.
652
- """
653
- if prompt_ids.shape[0] != 1:
654
- raise ValueError("Linear speculative decoding requires batch_size == 1")
655
-
656
- token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
657
- if eos_token_id is None:
658
- eos_token_id = getattr(self.config, "eos_token_id", None)
659
-
660
- device = prompt_ids.device
661
-
662
- def _set_diffusion_lm(val: bool):
663
- for layer in self.encoder.layers:
664
- if hasattr(layer.self_attn, "diffusion_lm"):
665
- layer.self_attn.diffusion_lm = val
666
-
667
- def _toggle_adapters(enable: bool):
668
- # No-op when no PEFT/LoRA modules are attached.
669
- for module in self.modules():
670
- if hasattr(module, "_disable_adapters"):
671
- module._disable_adapters = not enable
672
-
673
- # Prefill (causal, LoRA OFF).
674
- _set_diffusion_lm(False)
675
- _toggle_adapters(False)
676
- enc_out = self.encoder(
677
- input_ids=prompt_ids,
678
- past_key_values=DynamicCache(),
679
- use_cache=True,
680
- use_causal_mask=True,
681
- )
682
- past_key_values = enc_out.past_key_values
683
- last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
684
- nfe = 1
685
-
686
- if temperature > 0:
687
- next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
688
- else:
689
- next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
690
-
691
- if eos_token_id is not None and next_token.item() == eos_token_id:
692
- return torch.cat([prompt_ids, next_token], dim=1), nfe
693
-
694
- generated = [next_token]
695
- total_gen = 1
696
-
697
- while total_gen < max_new_tokens:
698
- cache_len = past_key_values.get_seq_length()
699
-
700
- block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
701
- block[0, 0] = next_token.item()
702
-
703
- # Draft phase (bidirectional, LoRA ON) — iterate at threshold>0 so
704
- # that even low-confidence blocks make progress.
705
- _set_diffusion_lm(True)
706
- _toggle_adapters(True)
707
- while True:
708
- is_mask = block == token_mask_id
709
- if not is_mask.any():
710
- break
711
-
712
- enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
713
- nfe += 1
714
-
715
- draft_logits = self.diffusion_head(enc_out.last_hidden_state)
716
- # LLaDA: logit[i] directly predicts position i — no shift needed.
717
-
718
- if temperature > 0:
719
- draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
720
- draft_tokens = torch.multinomial(
721
- draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
722
- ).view(1, block_length)
723
- else:
724
- draft_tokens = draft_logits.argmax(dim=-1)
725
- draft_probs = torch.softmax(draft_logits, dim=-1)
726
-
727
- if threshold > 0:
728
- draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
729
- draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
730
- unmask = draft_conf >= threshold
731
- # Force progress even when every masked position is below threshold.
732
- if not unmask.any():
733
- best_idx = draft_conf.view(-1).argmax()
734
- unmask = torch.zeros_like(is_mask, dtype=torch.bool)
735
- unmask.view(-1)[best_idx] = True
736
- block[unmask] = draft_tokens[unmask]
737
- else:
738
- block[is_mask] = draft_tokens[is_mask]
739
- break
740
-
741
- # Verify phase (causal, LoRA OFF).
742
- _set_diffusion_lm(False)
743
- _toggle_adapters(False)
744
- enc_out = self.encoder(
745
- input_ids=block,
746
- past_key_values=past_key_values,
747
- use_cache=True,
748
- use_causal_mask=True,
749
- )
750
- past_key_values = enc_out.past_key_values
751
- nfe += 1
752
-
753
- verify_logits = self.diffusion_head(enc_out.last_hidden_state)
754
- if temperature > 0:
755
- ar_tokens = torch.multinomial(
756
- torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]),
757
- num_samples=1,
758
- ).view(1, block_length)
759
- else:
760
- ar_tokens = verify_logits.argmax(dim=-1)
761
-
762
- # Accept consecutive matches; AR also gives one bonus token at the tail.
763
- accepted = 0
764
- for i in range(block_length - 1):
765
- if ar_tokens[0, i].item() == block[0, i + 1].item():
766
- accepted += 1
767
- else:
768
- break
769
- accepted += 1
770
-
771
- accepted_toks = ar_tokens[:, :accepted]
772
- generated.append(accepted_toks)
773
- total_gen += accepted
774
-
775
- _crop_dynamic_cache(past_key_values, cache_len + accepted)
776
- next_token = ar_tokens[:, accepted - 1 : accepted]
777
-
778
- if eos_token_id is not None:
779
- eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
780
- if len(eos_pos) > 0:
781
- first_eos = eos_pos[0].item()
782
- generated[-1] = accepted_toks[:, : first_eos + 1]
783
- total_gen = total_gen - accepted + first_eos + 1
784
- break
785
-
786
- # Thinking-budget enforcement: force end-think as next seed if budget exhausted.
787
- if end_think_token_id is not None and max_thinking_tokens is not None:
788
- if total_gen > max_thinking_tokens:
789
- all_gen = torch.cat(generated, dim=1)
790
- if not (all_gen == end_think_token_id).any():
791
- next_token = torch.tensor([[end_think_token_id]], device=device)
792
-
793
- if total_gen >= max_new_tokens:
794
- break
795
-
796
- all_generated = torch.cat(generated, dim=1)
797
- output_ids = torch.cat([prompt_ids, all_generated], dim=1)
798
- return output_ids, nfe
799
-
800
-
801
- # ─── Module-level helpers used by `generate` and `linear_spec_generate` ──
802
-
803
- def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
804
- """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
805
- if hasattr(past_key_values, 'crop'):
806
- past_key_values.crop(max_length)
807
- else:
808
- for layer_idx in range(len(past_key_values)):
809
- past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
810
- past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
811
- past_key_values._seen_tokens = max_length
812
-
813
-
814
- def _add_gumbel_noise(logits, temperature):
815
- """Gumbel-max sampling in float64 (low-precision Gumbel hurts MDM quality)."""
816
- if temperature == 0:
817
- return logits
818
- logits = logits.to(torch.float64)
819
- noise = torch.rand_like(logits, dtype=torch.float64)
820
- gumbel_noise = (- torch.log(noise)) ** temperature
821
- return logits.exp() / gumbel_noise
822
-
823
-
824
- def _get_num_transfer_tokens(mask_index, steps: int):
825
- """Even split of masked positions across `steps`, with remainder front-loaded."""
826
- mask_num = mask_index.sum(dim=1, keepdim=True)
827
- base = mask_num // steps
828
- remainder = mask_num % steps
829
- num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
830
- for i in range(mask_num.size(0)):
831
- num_transfer_tokens[i, : int(remainder[i])] += 1
832
- return num_transfer_tokens
833
-
834
-
835
- def _get_transfer_index(logits, temperature, mask_index, x, num_transfer_tokens, threshold=None):
836
- """Pick which masked positions to commit this denoising step.
837
-
838
- Returns (x0, transfer_index): x0 is argmax tokens (clamped to original x at
839
- non-masked positions); transfer_index is a bool mask over positions to
840
- finalize, chosen by top-k confidence (and filtered by `threshold` if given).
841
- """
842
- logits_with_noise = _add_gumbel_noise(logits, temperature=temperature)
843
- x0 = torch.argmax(logits_with_noise, dim=-1)
844
-
845
- p = F.softmax(logits, dim=-1)
846
- x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
847
-
848
- x0 = torch.where(mask_index, x0, x)
849
- confidence = torch.where(mask_index, x0_p, -np.inf)
850
-
851
- transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
852
- if threshold is not None:
853
- num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
854
- for j in range(confidence.shape[0]):
855
- _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
856
- transfer_index[j, select_index] = True
857
- if threshold is not None:
858
- for k in range(1, num_transfer_tokens[j]):
859
- if confidence[j, select_index[k]] < threshold:
860
- transfer_index[j, select_index[k]] = False
861
- return x0, transfer_index
862
-
863
-
864
- def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
865
- """Return a Bool mask of length len(log_w) with exactly k True."""
866
- g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
867
- topk = torch.topk(log_w + g, k).indices
868
- mask = torch.zeros_like(log_w, dtype=torch.bool)
869
- mask[topk] = True
870
- return mask