Update chat_utils.py

#1 opened by lwhalen7
.gitattributes CHANGED
@@ -34,5 +34,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 tokenizer.json filter=lfs diff=lfs merge=lfs -text
-*.png filter=lfs diff=lfs merge=lfs -text
-*.pdf filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,40 +1,29 @@
 ---
 library_name: transformers
-license: other
-license_name: cc-by-nc-4.0
-pipeline_tag: text-generation
+tags: []
 ---
 
-# Efficient-DLM-4B
+# Nemotron-Diffusion-Research-4B-v0
 
-<p align="center">
-📄 <a href="https://arxiv.org/pdf/2512.14067">Tech Report</a> &nbsp&nbsp|&nbsp&nbsp 🤗 <a href="https://huggingface.co/nvidia/Efficient-DLM-4B">Efficient-DLM-4B</a> &nbsp&nbsp|&nbsp&nbsp 🤗 <a href="https://huggingface.co/nvidia/Efficient-DLM-8B">Efficient-DLM-8B</a>
-</p>
+Developed by [DLER team](https://nv-dler.github.io/) @ NVR and will be updated actively. Contact Yonggan Fu and Pavlo Molchanov for any question.
 
 
-## Model Overview
+# Environment
 
-Efficient-DLM-4B is a base diffusion language model designed for parallel generation. It converts pretrained AR LMs into diffusion LMs through efficient continuous pretraining, enabling faster decoding while preserving the task accuracy of strong AR models. Efficient-DLM features block-wise attention with clean-context conditioning for KV-cache-friendly decoding, as well as position-dependent token masking to reduce the training–test mismatch in diffusion generation. See our [paper](https://arxiv.org/abs/2512.14067) for more technical details.
+Docker path: `/lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm.sqsh` on OCI-ORD/OCI-NRT. Apply for interactive nodes with the following command:
 
-<div align="center">
-<img src="https://huggingface.co/nvidia/Efficient-DLM-4B/resolve/main/images/result.png" alt="Accuracy vs throughput Pareto curve" width="500">
-</div>
-
-
-## Environment
-
-```bash
-transformers>=4.52.2
+```
+srun -A {account} --partition interactive --time 4:00:00 --gpus 8 --container-image /lustre/fsw/portfolios/nvr/users/yongganf/docker/megatron_py25_dllm.sqsh --container-mounts=$HOME:/home,/lustre:/lustre --pty bash
 ```
 
-
-## Chat with Efficient-DLM-4B
+## Chat with Our Model
 
-```python
+
+```
 from transformers import AutoModel, AutoTokenizer
 import torch
 
-repo_name = "nvidia/Efficient-DLM-4B"
+repo_name = "nvidia/Nemotron-Diffusion-Research-4B-v0"
 
 tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
 model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
@@ -42,30 +31,10 @@ model = model.cuda().to(torch.bfloat16)
 
 user_input = input("User: ").strip()
 
-prompt_ids = tokenizer(user_input, return_tensors="pt").input_ids.to(device="cuda")
-out_ids, nfe = model.generate(
-    prompt_ids,
-    max_new_tokens=128,
-    steps=128,
-    block_length=32,
-    shift_logits=False,
-    temperature=0.7,
-    threshold=0.9,
-)
+prompt_ids = tokenizer(user_input,return_tensors='pt').input_ids.to(device='cuda')
+out_ids, nfe = model.generate(prompt_ids, max_new_tokens=128, steps=128, block_length=32, shift_logits=False, threshold=0.9)
 
-response = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
-print(f"Model: {response}")
+tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
+print(f"Model: {tokenized_out}")
 print(f"[Num Function Eval (NFE)={nfe}]")
-```
-
-
-## Citation
-
-```
-@article{fu2025efficient,
-  title={Efficient-dlm: From autoregressive to diffusion language models, and beyond in speed},
-  author={Fu, Yonggan and Whalen, Lexington and Ye, Zhifan and Dong, Xin and Diao, Shizhe and Liu, Jingyu and Wu, Chengyue and Zhang, Hao and Xie, Enze and Han, Song and others},
-  journal={arXiv preprint arXiv:2512.14067},
-  year={2025}
-}
 ```
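
For reference, the updated snippet wraps naturally into a small REPL-style loop. A minimal sketch, assuming the custom `generate()` signature shown in this diff (note `temperature` is no longer accepted, matching the `modeling` change below):

```python
# Minimal chat loop around the README snippet above; a sketch assuming the
# custom generate() signature from this diff (temperature was removed).
from transformers import AutoModel, AutoTokenizer
import torch

repo_name = "nvidia/Nemotron-Diffusion-Research-4B-v0"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True)
model = model.cuda().to(torch.bfloat16)

while True:
    user_input = input("User: ").strip()
    if not user_input:
        break
    prompt_ids = tokenizer(user_input, return_tensors="pt").input_ids.to(device="cuda")
    out_ids, nfe = model.generate(
        prompt_ids,
        max_new_tokens=128,   # tokens to denoise after the prompt
        steps=128,            # total diffusion steps across all blocks
        block_length=32,      # decode in blocks of 32 tokens
        shift_logits=False,
        threshold=0.9,        # confidence threshold for committing tokens
    )
    response = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]
    print(f"Model: {response}")
    print(f"[Num Function Eval (NFE)={nfe}]")
```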
chat_utils.py CHANGED
@@ -3,32 +3,20 @@ import torch
 import torch.nn.functional as F
 
 
-def add_gumbel_noise(logits, temperature):
-    '''
-    The Gumbel max is a method for sampling categorical distributions.
-    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
-    Thus, we use float64.
-    '''
-    if temperature == 0:
-        return logits
-    logits = logits.to(torch.float64)
-    noise = torch.rand_like(logits, dtype=torch.float64)
-    gumbel_noise = (- torch.log(noise)) ** temperature
-    return logits.exp() / gumbel_noise
+def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None,neg_entropy=False):
+    x0 = torch.argmax(logits, dim=-1) # b, l
+
+    if temperature is None or temperature <= 0:
+        temperature = 1.0
 
-
-def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
-    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
-    x0 = torch.argmax(logits_with_noise, dim=-1)
-
     if remasking == 'low_confidence':
         # p = F.softmax(logits.to(torch.float64), dim=-1)
-        p = F.softmax(logits, dim=-1)
+        p = F.softmax(logits/temperature, dim=-1)
         x0_p = torch.squeeze(
             torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
     elif remasking == 'top_p_margin':
         # Compute probabilities
-        p = F.softmax(logits, dim=-1) # (B, L, V)
+        p = F.softmax(logits/temperature, dim=-1) # (B, L, V)
         # Top-2 per position
         top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
         margin = top2[..., 0] - top2[..., 1] # (B, L)
@@ -64,7 +52,7 @@ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
     # Calculate negative entropy if requested
     if neg_entropy:
         # p = F.softmax(logits.to(torch.float64), dim=-1)
-        p = F.softmax(logits, dim=-1)
+        p = F.softmax(logits/temperature, dim=-1)
         epsilon = 1e-10
         log_probs = torch.log(p + epsilon)
         confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
@@ -216,7 +204,6 @@ def generate_with_prefix_cache_block_diff(
             use_cache=True
         )
         past_key_values = output.past_key_values
-        nfe += 1
 
         if dream_style and num_block < num_blocks - 1:
             # refresh context-next logit for the next block
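
With the Gumbel-noise path removed, `get_transfer_index` takes plain argmax predictions and uses `temperature` only to rescale the softmax behind the confidence scores. A hypothetical standalone sketch of that commit rule (function name and the exact top-k/threshold combination are illustrative, not the repo's code):

```python
# Sketch of the confidence-based commit rule after this change: predictions are
# plain argmax; temperature only flattens or sharpens the confidence scores
# used to decide which masked positions to commit this step.
import torch
import torch.nn.functional as F

def commit_by_confidence(logits, mask_index, num_transfer_tokens, temperature=1.0, threshold=None):
    x0 = torch.argmax(logits, dim=-1)                         # greedy prediction, (b, l)
    p = F.softmax(logits / temperature, dim=-1)
    x0_p = torch.gather(p, -1, x0.unsqueeze(-1)).squeeze(-1)  # prob of each argmax token
    # only still-masked positions compete for being committed
    x0_p = torch.where(mask_index, x0_p, torch.full_like(x0_p, float("-inf")))
    topk = torch.topk(x0_p, k=num_transfer_tokens, dim=-1).indices
    transfer = torch.zeros_like(mask_index)
    transfer.scatter_(1, topk, torch.ones_like(topk, dtype=torch.bool))
    if threshold is not None:
        # optionally also commit everything already above the confidence threshold
        transfer |= (x0_p >= threshold) & mask_index
    return x0, transfer

logits = torch.randn(1, 8, 100)
mask_index = torch.ones(1, 8, dtype=torch.bool)
x0, transfer = commit_by_confidence(logits, mask_index, num_transfer_tokens=2, threshold=0.9)
print(x0[transfer])  # tokens committed this step
```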
config.json CHANGED
@@ -1,14 +1,14 @@
 {
   "adaptive_mask_rate": false,
   "architectures": [
-    "EfficientDLM"
+    "DiffEncoderModel"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
   "attn_implementation": "sdpa",
   "auto_map": {
-    "AutoConfig": "configuration_edlm.EfficientDLMConfig",
-    "AutoModel": "modeling_edlm.EfficientDLM"
+    "AutoConfig": "configuration_nvrdiff.NVRDiffConfig",
+    "AutoModel": "modeling_nvrdiff.DiffEncoderModel"
   },
   "block_size": 32,
   "diff_loss_weight": 1,
@@ -38,6 +38,7 @@
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "rope_theta": 1000000,
+  "seq_length": 1024,
   "sliding_window": null,
   "tie_word_embeddings": false,
   "tok_mask_half_life_ratio": null,
configuration_edlm.py → configuration_nvrdiff.py RENAMED
@@ -22,7 +22,7 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-class EfficientDLMConfig(PretrainedConfig):
+class NVRDiffConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
     Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -172,6 +172,7 @@ class EfficientDLMConfig(PretrainedConfig):
         max_window_layers=28,
         attention_dropout=0.0,
         attn_implementation="sdpa",
+        seq_length=1024,
         mask_token_id=-1,
         dlm_type='llada',
         random_length_prob=None,
@@ -221,6 +222,7 @@ class EfficientDLMConfig(PretrainedConfig):
         rope_config_validation(self)
 
         self.attn_implementation = attn_implementation
+        self.seq_length = seq_length
 
         self.mask_token_id = mask_token_id
         self.dlm_type = dlm_type
@@ -245,4 +247,4 @@ class EfficientDLMConfig(PretrainedConfig):
         )
 
 
-__all__ = ["EfficientDLMConfig"]
+__all__ = ["Qwen3Config"]
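
The new `seq_length` argument is stored on the config and read by the attention module in `modeling_nvrdiff.py` below. A minimal sketch of constructing it directly (the local import path and the value passed are illustrative; the checkpoint ships 1024):

```python
# Direct construction of the renamed config; a sketch assuming
# configuration_nvrdiff.py is importable from the working directory.
from configuration_nvrdiff import NVRDiffConfig

cfg = NVRDiffConfig(seq_length=2048)
assert cfg.seq_length == 2048  # stored by __init__ via self.seq_length = seq_length
```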
images/result.png → model-00001-of-00002.safetensors RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9b81fe6641cd8816c4041697b0ac2cb1c4fcdfc2166504e2bde174c67ddc7eae
-size 221103
+oid sha256:42a85e2aa98cd482ece3ec213560fa67c1e15cbfa2a58c366e2c516887e50927
+size 4967215816
model.safetensors → model-00002-of-00002.safetensors RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:77c83e52654fd49874f6b09cf78b739da454c8320dd54c6970c3e5f88dc5e7c4
-size 8822895320
+oid sha256:fcc2f6d41ac9fec18b6593d91efb2d1cd5abf7c76433d98887d65d8306b96523
+size 3855679488
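
Both pointers follow the Git LFS v1 spec: `oid` is the SHA-256 of the real file contents and `size` its byte length. A sketch for verifying a downloaded shard against its pointer:

```python
# Recompute the LFS oid of a downloaded shard and compare it to the pointer above.
import hashlib

def lfs_oid(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

assert lfs_oid("model-00002-of-00002.safetensors") == (
    "fcc2f6d41ac9fec18b6593d91efb2d1cd5abf7c76433d98887d65d8306b96523"
)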
model.safetensors.index.json ADDED
@@ -0,0 +1,406 @@
+{
+  "metadata": {
+    "total_size": 8822848512
+  },
+  "weight_map": {
+    "diffusion_head.weight": "model-00002-of-00002.safetensors",
+    "encoder.embed_tokens.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.20.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.20.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.21.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.22.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.28.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.29.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.30.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.31.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.32.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.33.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.34.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.35.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "encoder.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "encoder.norm.weight": "model-00002-of-00002.safetensors"
+  }
+}
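
The index added above lets loaders fetch only the shard that holds a given tensor. A sketch using the `safetensors` package:

```python
# Look up which shard holds a tensor, then load just that tensor from it.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "encoder.layers.20.mlp.down_proj.weight"
shard = index["weight_map"][name]          # "model-00002-of-00002.safetensors"
with safe_open(shard, framework="pt") as f:
    tensor = f.get_tensor(name)
print(tensor.shape)
```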
modeling_edlm.py → modeling_nvrdiff.py RENAMED
@@ -22,7 +22,7 @@ from transformers.generation import GenerationMixin
22
  import math
23
 
24
  from .modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel, Qwen3Attention, apply_rotary_pos_emb, repeat_kv
25
- from .configuration_edlm import EfficientDLMConfig
26
  from .chat_utils import generate_with_prefix_cache_block_diff
27
 
28
  # @torch.compile(dynamic=True, mode="reduce-overhead")
@@ -37,32 +37,46 @@ class Qwen3FlexAttention(Qwen3Attention):
37
  def __init__(self, *args, **kwargs):
38
  super().__init__(*args, **kwargs)
39
 
40
- self.block_size = self.block_size_orig = self.config.block_size
 
 
41
 
42
- self.bidirectional_mask = None
43
  if self.config.dlm_paradigm == 'bidirectional':
44
  self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
 
 
 
 
45
  elif self.config.dlm_paradigm == 'block_diff':
46
- self.block_diff_mask = None
47
  else:
48
  raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
49
 
 
 
50
  self.mode = 'bidirectional'
51
 
52
  import torch._dynamo.config as dcfg
53
  dcfg.cache_size_limit = 512
54
 
55
 
56
- def set_attention_mode(self, mode, block_size=None):
57
  self.mode = mode
 
58
  self.block_size = block_size
59
 
60
 
61
- def compute_block_mask(self, mode, q_len, block_size=None):
62
 
63
  def bidirectional_mask(b, h, q, kv):
64
  return (q >= kv) | (q < kv)
65
 
 
 
 
 
 
 
66
  def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
67
  """
68
  Constructs the specialized block diffusion attention mask for training
@@ -70,11 +84,13 @@ class Qwen3FlexAttention(Qwen3Attention):
70
  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
71
  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
72
  - **Block Causal Mask (M_BC)**: Attention to update x0
 
73
  Args:
74
  b, h: Batch and head indices (ignored for mask logic).
75
  q_idx, kv_idx: Query and Key indices.
76
  seq_len: Total sequence length.
77
  block_size: Defines the block structure.
 
78
  Returns:
79
  A boolean attention mask.
80
  """
@@ -109,14 +125,28 @@ class Qwen3FlexAttention(Qwen3Attention):
109
 
110
  if mode == 'bidirectional':
111
  attn_mask = bidirectional_mask
 
 
 
 
 
 
112
  elif mode == 'block_diff':
113
  assert block_size is not None
114
- attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
115
  else:
116
  raise ValueError(f"Unknown attention mode: {mode}")
117
 
 
 
 
 
 
 
 
 
118
  block_mask = create_block_mask(
119
- attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
120
  )
121
 
122
  return block_mask
@@ -166,12 +196,28 @@ class Qwen3FlexAttention(Qwen3Attention):
166
  value_states = repeat_kv(value_states, self.num_key_value_groups)
167
 
168
  if self.mode == 'bidirectional':
169
- if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
170
- block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
171
  else:
172
  block_mask = self.bidirectional_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  elif self.mode == 'block_diff':
174
- if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
175
  block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
176
  else:
177
  block_mask = self.block_diff_mask
@@ -195,14 +241,14 @@ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
195
  return mask
196
 
197
 
198
- class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
199
  """
200
  A single model with:
201
  - a bidirectional encoder + diffusion‐LM head over A
202
  - a causal decoder + LM head over B, conditioned on F_A
203
  """
204
 
205
- def __init__(self, config: EfficientDLMConfig):
206
  super().__init__(config)
207
 
208
  self.mask_token_id = config.mask_token_id
@@ -210,7 +256,7 @@ class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
210
  diffusion_config = copy.deepcopy(config)
211
  diffusion_config.diffusion_lm = True
212
 
213
- if config.dlm_paradigm in ['block_diff']:
214
  diffusion_config.attn_class = Qwen3FlexAttention
215
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
216
  diffusion_config.attn_class = Qwen3Attention
@@ -256,13 +302,16 @@ class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
256
  ):
257
  """
258
  Two-stage corruption with optional per-block sampling.
 
259
  • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
260
  • Stage 2: sample exactly k positions with weights
261
  w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
262
  uniform when m→1).
 
263
  If `block_size` is given, the procedure is run *independently*
264
  inside each contiguous block of that length (last block may be shorter).
265
  When block_size is provided, m is sampled per-block and p_mask is per-block.
 
266
  Args
267
  ----
268
  input_ids : (B, L) LongTensor
@@ -350,73 +399,81 @@ class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
350
  masked_indices: Optional[torch.Tensor] = None,
351
  p_mask: Optional[torch.Tensor] = None,
352
  loss_mask: Optional[torch.Tensor] = None,
353
- skip_loss: bool = False,
354
- inputs_embeds: Optional[torch.FloatTensor] = None,
355
  **kwargs,
356
  ) -> CausalLMOutputWithPast:
357
 
358
- if inputs_embeds is not None:
359
- noisy_inputs = None
360
- else:
361
- batch_size, seq_len = input_ids.shape
362
-
363
- if self.config.dlm_paradigm == 'bidirectional':
364
- if labels is not None and torch.rand(1) < self.config.random_length_prob:
365
- random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
366
- input_ids = input_ids[:, :random_length]
367
- labels = labels[:, :random_length]
368
-
369
- if attention_mask is not None:
370
- attention_mask = attention_mask[:, :random_length]
371
- if position_ids is not None:
372
- position_ids = position_ids[:, :random_length]
373
 
374
-        elif self.config.dlm_paradigm == 'block_diff':
-            if labels is not None and block_size is None:
-                if torch.rand(1) < self.config.random_length_prob:
-                    block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
-                else:
-                    block_size = self.config.block_size
 
-        if labels is not None and self.config.dlm_paradigm != 'autoregressive':
-            if masked_indices is not None:
-                #assert p_mask is not None
 
-                if loss_mask is not None:
-                    masked_indices[loss_mask == 0] = 0
 
-                noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
 
-            else:
-                if self.config.tok_mask_half_life_ratio is not None:
-                    noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
-                else:
-                    noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
 
         else:
-            noisy_inputs = input_ids
-            masked_indices = None
-            p_mask = None
 
-        if self.config.dlm_paradigm in ['block_diff']:
-            for layer in self.encoder.layers:
-                if hasattr(layer.self_attn, 'set_attention_mode'):
-                    layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
 
-        input_ids_len = noisy_inputs.shape[1]
-        if labels is not None and self.config.dlm_paradigm == 'block_diff':
-            if position_ids is None:
-                position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
-            noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
 
-        if block_diff_ppl:
-            if position_ids is None:
-                position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
 
         enc_out = self.encoder(
             past_key_values=past_key_values,
             input_ids=noisy_inputs,
-            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             position_ids=position_ids,
             is_training=(labels is not None) or (block_diff_ppl),
@@ -429,56 +486,56 @@ class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
         logits = logits[:, :input_ids_len]
 
         loss = None
-        if labels is not None and not skip_loss:
-            if self.config.dlm_paradigm == 'autoregressive':
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
 
-                if loss_mask is None:
-                    loss_fct = CrossEntropyLoss()
-                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-                    shift_labels = shift_labels.view(-1)
-                    loss = loss_fct(shift_logits, shift_labels)
-
-                else:
-                    loss_mask = loss_mask[..., 1:].contiguous()
-
-                    loss_fct = CrossEntropyLoss(reduction='none')
-                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-                    shift_labels = shift_labels.view(-1)
-                    shift_labels = shift_labels.to(shift_logits.device)
 
-                    token_losses = loss_fct(shift_logits, shift_labels)
 
-                    loss = token_losses[loss_mask].sum() / loss_mask.sum()
-
-            else:
-                # Handle DREAM vs LLADA style losses
-                if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
-                    logits = logits[..., :-1, :].contiguous()
-                    labels = labels[..., 1:].contiguous()
-                    masked_indices = masked_indices[:, 1:]
-                    p_mask = p_mask[:, 1:]
-
-                    # Calculate token-wise cross entropy loss for masked positions in B
-                    token_loss = torch.nn.functional.cross_entropy(
-                        logits[masked_indices],
-                        labels[masked_indices],
-                        reduction='none'
-                    ) / p_mask[masked_indices]
 
-                    loss = token_loss.sum() / masked_indices.sum()
 
         return CausalLMOutputWithPast(
             loss=loss if not is_teacher else logits,
             logits=logits,
             past_key_values=enc_out.past_key_values,
-            hidden_states=enc_out.last_hidden_state,
             attentions=None,
         )
 
 
-    def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, temperature=0):
         out_ids, nfe = generate_with_prefix_cache_block_diff(
             model=self,
             prompt=prompt_ids,
@@ -489,7 +546,6 @@ class EfficientDLM(Qwen3PreTrainedModel, GenerationMixin):
             mask_id=self.mask_token_id,
             threshold=threshold,
             shift_logits=shift_logits,
-            temperature=temperature,
             neg_entropy=False,
         )
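The left pane above deletes the explicit training-loss branch (it reappears commented out in the new version further down). For readers tracking the change, here is a minimal self-contained sketch of the DREAM-style objective that branch implemented: cross-entropy over masked positions only, reweighted by the inverse masking probability `p_mask`. The function name and toy shapes are mine, not from the repo.

```python
import torch
import torch.nn.functional as F

def dream_style_diffusion_loss(logits, labels, masked_indices, p_mask):
    """Masked cross-entropy, importance-weighted by the per-token mask
    probability, mirroring the DREAM-style branch removed above.

    logits: (B, L, V); labels: (B, L); masked_indices: (B, L) bool;
    p_mask: (B, L) probability with which each token was masked.
    """
    # Shift so position i predicts token i+1, as in the removed code.
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()
    masked_indices = masked_indices[:, 1:]
    p_mask = p_mask[:, 1:]

    # Cross-entropy on masked positions only, divided by p_mask so the
    # estimator stays comparable across noise levels.
    token_loss = F.cross_entropy(
        logits[masked_indices], labels[masked_indices], reduction='none'
    ) / p_mask[masked_indices]
    return token_loss.sum() / masked_indices.sum()

# Toy check on random tensors (hypothetical shapes).
B, L, V = 2, 32, 100
logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
masked = torch.rand(B, L) < 0.5
p_mask = torch.full((B, L), 0.5)
print(dream_style_diffusion_loss(logits, labels, masked, p_mask))
```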
 
 
 import math
 
 from .modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel, Qwen3Attention, apply_rotary_pos_emb, repeat_kv
+from .configuration_nvrdiff import NVRDiffConfig
 from .chat_utils import generate_with_prefix_cache_block_diff
 
 # @torch.compile(dynamic=True, mode="reduce-overhead")
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+        self.max_seq_length = self.config.seq_length
+        self.prefix_len_orig = int(self.config.seq_length * self.config.prefix_ratio)
+        self.block_size_orig = self.config.block_size
 
         if self.config.dlm_paradigm == 'bidirectional':
             self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
+        elif self.config.dlm_paradigm == 'prefix_bidirectional':
+            self.prefix_bidirectional_mask = self.compute_block_mask(mode='prefix_bidirectional', prefix_len=self.prefix_len_orig)
+        elif self.config.dlm_paradigm == 'efficient_block_diff':
+            self.efficient_block_diff_mask = self.compute_block_mask(mode='efficient_block_diff', block_size=self.block_size_orig)
         elif self.config.dlm_paradigm == 'block_diff':
+            self.block_diff_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size_orig)
         else:
             raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
 
+        self.prefix_len = self.prefix_len_orig
+        self.block_size = self.block_size_orig
         self.mode = 'bidirectional'
 
         import torch._dynamo.config as dcfg
         dcfg.cache_size_limit = 512
 
+    def set_attention_mode(self, mode, prefix_len=None, block_size=None):
         self.mode = mode
+        self.prefix_len = prefix_len
         self.block_size = block_size
 
+    def compute_block_mask(self, mode, prefix_len=None, q_len=None, block_size=None):
 
         def bidirectional_mask(b, h, q, kv):
             return (q >= kv) | (q < kv)
 
+        def prefix_bidirectional_mask(prefix_len, b, h, q, kv):
+            return (kv <= prefix_len) | (q >= prefix_len)
+
+        def efficient_block_diff_mask(block_size, b, h, q, kv):
+            return (q // block_size) >= (kv // block_size)
+
         def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
             """
             Constructs the specialized block diffusion attention mask for training, composed of:
             - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
             - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
             - **Block Causal Mask (M_BC)**: Attention to update x0
+
             Args:
                 b, h: Batch and head indices (ignored for mask logic).
                 q_idx, kv_idx: Query and Key indices.
                 n: Total sequence length.
                 block_size: Defines the block structure.
+
             Returns:
                 A boolean attention mask.
             """
 
         if mode == 'bidirectional':
             attn_mask = bidirectional_mask
+        elif mode == 'prefix_bidirectional':
+            assert prefix_len is not None
+            attn_mask = lambda b, h, q, kv: prefix_bidirectional_mask(prefix_len, b, h, q, kv)
+        elif mode == 'efficient_block_diff':
+            assert block_size is not None
+            attn_mask = lambda b, h, q, kv: efficient_block_diff_mask(block_size, b, h, q, kv)
         elif mode == 'block_diff':
             assert block_size is not None
+            attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
         else:
             raise ValueError(f"Unknown attention mode: {mode}")
 
+        if q_len is not None:
+            Q_LEN = q_len
+        else:
+            if mode == 'block_diff':
+                Q_LEN = self.max_seq_length * 2
+            else:
+                Q_LEN = self.max_seq_length
+
         block_mask = create_block_mask(
+            attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
         )
 
         return block_mask
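A small illustration of the mask predicates defined above (this snippet is not part of the repo). It evaluates each predicate densely on an 8-token grid so the attention patterns are visible; note that `bidirectional_mask` is deliberately a tautology, since flex attention expects a callable predicate even for full attention.

```python
import torch

# Dense re-implementations of the predicates above on a toy grid.
# prefix_len/block_size values are arbitrary; the real code builds these
# lazily as sparse BlockMasks via create_block_mask.
L, prefix_len, block_size = 8, 4, 2
q = torch.arange(L).unsqueeze(1)   # query index, column vector
kv = torch.arange(L).unsqueeze(0)  # key index, row vector

full_bidir = (q >= kv) | (q < kv)                      # always True: full attention
# prefix queries see only the prefix; suffix queries see everything
prefix_bidir = (kv <= prefix_len) | (q >= prefix_len)
# block-causal: each block attends to itself (bidirectionally) and earlier blocks
block_causal = (q // block_size) >= (kv // block_size)

for name, m in [("bidirectional", full_bidir),
                ("prefix_bidirectional", prefix_bidir),
                ("efficient_block_diff", block_causal)]:
    print(name)
    print(m.int())
```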
 
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if self.mode == 'bidirectional':
+            if q_len != self.bidirectional_mask.shape[-2]:
+                block_mask = self.compute_block_mask(mode='bidirectional', prefix_len=self.prefix_len, q_len=q_len)
             else:
                 block_mask = self.bidirectional_mask
+
+        elif self.mode == 'prefix_bidirectional':
+            if self.prefix_len != self.prefix_len_orig or q_len != self.prefix_bidirectional_mask.shape[-2]:
+                block_mask = self.compute_block_mask(mode='prefix_bidirectional', prefix_len=self.prefix_len, q_len=q_len)
+
+                # print('create new block mask length for:', self.prefix_len)
+                # print(f"Block mask shape: {block_mask.shape}")
+                # print("Block mask pattern:")
+                # print(block_mask)
+            else:
+                block_mask = self.prefix_bidirectional_mask
+        elif self.mode == 'efficient_block_diff':
+            if self.block_size != self.block_size_orig or q_len != self.efficient_block_diff_mask.shape[-2]:
+                block_mask = self.compute_block_mask(mode='efficient_block_diff', block_size=self.block_size, q_len=q_len)
+            else:
+                block_mask = self.efficient_block_diff_mask
         elif self.mode == 'block_diff':
+            if self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
                 block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
             else:
                 block_mask = self.block_diff_mask
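The branchy logic above implements a cache-or-rebuild policy: each mode keeps the BlockMask precomputed in `__init__` and only calls `compute_block_mask` again when `block_size`, `prefix_len`, or the query length changes, since building the mask is the expensive step. A minimal sketch of the same pattern with PyTorch's flex attention (requires a recent PyTorch that ships `torch.nn.attention.flex_attention`; shapes and the `BLOCK` value are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK, L, H, D = 32, 256, 4, 64

# Block-causal predicate, same shape as efficient_block_diff_mask above.
def block_causal(b, h, q_idx, kv_idx):
    return (q_idx // BLOCK) >= (kv_idx // BLOCK)

# Built once and reused across forward passes; rebuild only if BLOCK or L changes.
block_mask = create_block_mask(block_causal, B=None, H=None,
                               Q_LEN=L, KV_LEN=L, device=device)

q, k, v = (torch.randn(1, H, L, D, device=device) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 4, 256, 64])
```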
 
         return mask
 
 
+class DiffEncoderModel(Qwen3PreTrainedModel, GenerationMixin):
     """
     A single model with:
       - a bidirectional encoder + diffusion‐LM head over A
       - a causal decoder + LM head over B, conditioned on F_A
     """
 
+    def __init__(self, config: NVRDiffConfig):
         super().__init__(config)
 
         self.mask_token_id = config.mask_token_id
 
         diffusion_config = copy.deepcopy(config)
         diffusion_config.diffusion_lm = True
 
+        if config.dlm_paradigm in ['prefix_bidirectional', 'efficient_block_diff', 'block_diff']:
             diffusion_config.attn_class = Qwen3FlexAttention
         elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
             diffusion_config.attn_class = Qwen3Attention
 
     ):
         """
         Two-stage corruption with optional per-block sampling.
+
         • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
         • Stage 2: sample exactly k positions with weights
               w_i(m) = exp[ λ · (1−m) · i ]   (late-heavy when m→0,
               uniform when m→1).
+
         If `block_size` is given, the procedure is run *independently*
         inside each contiguous block of that length (last block may be shorter).
         When block_size is provided, m is sampled per-block and p_mask is per-block.
+
         Args
         ----
         input_ids : (B, L) LongTensor
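A runnable sketch of the two-stage corruption the docstring describes, applied to a whole sequence as a single block (this is not the repo's implementation). How λ is derived from `tok_mask_half_life_ratio` is not visible in this diff, so `lam` below is an arbitrary stand-in:

```python
import torch

def two_stage_corrupt(input_ids, mask_token_id, eps=1e-3, lam=0.05):
    """Sketch of the two-stage corruption in the docstring above.
    `lam` stands in for the rate lambda; the repo's mapping from
    `tok_mask_half_life_ratio` to lambda is not shown in this diff.
    Returns (noisy_inputs, masked_indices, p_mask).
    """
    B, L = input_ids.shape
    noisy = input_ids.clone()
    masked = torch.zeros(B, L, dtype=torch.bool)
    p_mask = torch.empty(B, L)
    for b in range(B):
        # Stage 1: masking budget k from m ~ U(eps, 1) (exact budget).
        m = eps + (1 - eps) * torch.rand(()).item()
        k = max(1, round(m * L))
        # Stage 2: draw exactly k positions; weights grow with position i,
        # so sampling is late-heavy when m -> 0 and near-uniform when m -> 1.
        i = torch.arange(L, dtype=torch.float)
        w = torch.exp(lam * (1 - m) * i)
        idx = torch.multinomial(w, k, replacement=False)
        masked[b, idx] = True
        noisy[b, idx] = mask_token_id
        p_mask[b] = m  # per-block mask probability, as the docstring notes
    return noisy, masked, p_mask

ids = torch.randint(0, 100, (2, 32))
noisy, masked, p = two_stage_corrupt(ids, mask_token_id=99)
print(masked.float().mean(dim=1))  # roughly m per row
```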
 
         masked_indices: Optional[torch.Tensor] = None,
         p_mask: Optional[torch.Tensor] = None,
         loss_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
 
+        batch_size, seq_len = input_ids.shape
 
+        if self.config.dlm_paradigm == 'bidirectional':
+            if labels is not None and torch.rand(1) < self.config.random_length_prob:
+                random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
+                input_ids = input_ids[:, :random_length]
+                labels = labels[:, :random_length]
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, :random_length]
+                if position_ids is not None:
+                    position_ids = position_ids[:, :random_length]
+
+        elif self.config.dlm_paradigm == 'prefix_bidirectional':
+            if labels is not None and split_len is None:
+                if torch.rand(1) < self.config.random_length_prob:
+                    split_len = torch.randint(1, seq_len//64, (1,)).item() * 64 ## [64, seq_len] divisible by 64
+                else:
+                    split_len = int(seq_len * self.config.prefix_ratio)
 
+        elif self.config.dlm_paradigm == 'efficient_block_diff':
+            if labels is not None and block_size is None:
+                if torch.rand(1) < self.config.random_length_prob:
+                    block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
+                else:
+                    block_size = self.config.block_size
 
+        elif self.config.dlm_paradigm == 'block_diff':
+            if labels is not None and block_size is None:
+                if torch.rand(1) < self.config.random_length_prob:
+                    block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
+                else:
+                    block_size = self.config.block_size
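One nit on the inline comments above: `torch.randint`'s upper bound is exclusive, so `torch.randint(1, 8, (1,)).item() * 4` draws from {4, 8, …, 28} rather than [4, 32], and the prefix split similarly tops out at `seq_len - 64`. A quick check:

```python
import torch

# randint's high bound is exclusive, so this yields {4, 8, ..., 28}, not 32.
sizes = sorted({torch.randint(1, 8, (1,)).item() * 4 for _ in range(10_000)})
print(sizes)  # [4, 8, 12, 16, 20, 24, 28]
```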
 
+        if labels is not None and self.config.dlm_paradigm != 'autoregressive':
+            if masked_indices is not None:
+                #assert p_mask is not None
 
+                if loss_mask is not None:
+                    masked_indices[loss_mask == 0] = 0
+
+                noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
 
             else:
+                if self.config.tok_mask_half_life_ratio is not None:
+                    noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
+                else:
+                    noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
 
+        else:
+            noisy_inputs = input_ids
+            masked_indices = None
+            p_mask = None
+
+        if self.config.dlm_paradigm in ['prefix_bidirectional', 'efficient_block_diff', 'block_diff']:
+            for layer in self.encoder.layers:
+                if hasattr(layer.self_attn, 'set_attention_mode'):
+                    layer.self_attn.set_attention_mode(self.config.dlm_paradigm, prefix_len=split_len, block_size=block_size)
 
+        input_ids_len = noisy_inputs.shape[1]
+        if labels is not None and self.config.dlm_paradigm == 'block_diff':
+            if position_ids is None:
+                position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
+            noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
 
+        if block_diff_ppl:
+            if position_ids is None:
+                position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
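For 'block_diff' training, the code above feeds the encoder the noisy sequence concatenated with its clean copy, while `position_ids` spans only the first half; the encoder presumably reuses those positions for the clean half (that logic is outside this diff), and logits are truncated back to the first half afterwards. A shape-only sketch (values arbitrary, not the repo's code):

```python
import torch

# Shape-only sketch of the [noisy || clean] layout used for 'block_diff'.
B, L, mask_token_id = 2, 8, 0
input_ids = torch.randint(1, 50, (B, L))
masked = torch.rand(B, L) < 0.3
noisy = input_ids.masked_fill(masked, mask_token_id)

position_ids = torch.arange(L).unsqueeze(0)          # covers only the noisy half
model_input = torch.cat([noisy, input_ids], dim=1)   # (B, 2L)
print(model_input.shape, position_ids.shape)
# Logits are later truncated back to the first half: logits = logits[:, :L]
```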
 
         enc_out = self.encoder(
             past_key_values=past_key_values,
             input_ids=noisy_inputs,
             attention_mask=attention_mask,
             position_ids=position_ids,
             is_training=(labels is not None) or (block_diff_ppl),
 
         logits = logits[:, :input_ids_len]
 
         loss = None
+        # if labels is not None:
+        #     if self.config.dlm_paradigm == 'autoregressive':
+        #         shift_logits = logits[..., :-1, :].contiguous()
+        #         shift_labels = labels[..., 1:].contiguous()
 
+        #         if loss_mask is None:
+        #             loss_fct = CrossEntropyLoss()
+        #             shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+        #             shift_labels = shift_labels.view(-1)
+        #             loss = loss_fct(shift_logits, shift_labels)
+
+        #         else:
+        #             loss_mask = loss_mask[..., 1:].contiguous()
+
+        #             loss_fct = CrossEntropyLoss(reduction='none')
+        #             shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+        #             shift_labels = shift_labels.view(-1)
+        #             shift_labels = shift_labels.to(shift_logits.device)
 
+        #             token_losses = loss_fct(shift_logits, shift_labels)
 
+        #             loss = token_losses[loss_mask].sum() / loss_mask.sum()
+
+        #     else:
+        #         # Handle DREAM vs LLADA style losses
+        #         if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
+        #             logits = logits[..., :-1, :].contiguous()
+        #             labels = labels[..., 1:].contiguous()
+        #             masked_indices = masked_indices[:, 1:]
+        #             p_mask = p_mask[:, 1:]
+
+        #             # Calculate token-wise cross entropy loss for masked positions in B
+        #             token_loss = torch.nn.functional.cross_entropy(
+        #                 logits[masked_indices],
+        #                 labels[masked_indices],
+        #                 reduction='none'
+        #             ) / p_mask[masked_indices]
 
+        #             loss = token_loss.sum() / masked_indices.sum()
 
         return CausalLMOutputWithPast(
             loss=loss if not is_teacher else logits,
             logits=logits,
             past_key_values=enc_out.past_key_values,
+            hidden_states=None,
             attentions=None,
         )
 
 
+    def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold):
         out_ids, nfe = generate_with_prefix_cache_block_diff(
             model=self,
             prompt=prompt_ids,
 
             mask_id=self.mask_token_id,
             threshold=threshold,
             shift_logits=shift_logits,
             neg_entropy=False,
         )
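Note the updated `generate` signature drops `temperature`; decoding is now governed by the confidence `threshold` alone. A hypothetical call under the new signature, assuming an already-loaded `model` and a tokenized `prompt_ids` tensor (argument values are illustrative, not prescribed by the repo):

```python
# Hypothetical usage of the updated signature (no `temperature` argument).
# `model` and `prompt_ids` are assumed to be prepared beforehand.
out_ids, nfe = model.generate(
    prompt_ids,            # (1, prompt_len) LongTensor on the model's device
    max_new_tokens=256,
    steps=256,             # diffusion refinement steps
    block_length=32,       # decoded block size
    shift_logits=False,
    threshold=0.9,         # confidence threshold for committing tokens
)
# Per the wrapper above, this returns the generated ids plus the number of
# function evaluations (NFE).
```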
 
modeling_qwen3.py CHANGED
@@ -35,14 +35,8 @@ from transformers.modeling_outputs import (
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
-from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
-try:
-    from transformers.utils import TransformersKwargs
-except ImportError:
-    from typing import TypedDict
-    class TransformersKwargs(TypedDict, total=False):
-        pass
-from .configuration_edlm import EfficientDLMConfig
 
 
 if is_torch_flex_attn_available():
@@ -166,7 +160,7 @@ def eager_attention_forward(
 class Qwen3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: EfficientDLMConfig, layer_idx: int):
         super().__init__()
         self.config = config
 
@@ -312,7 +306,7 @@ class Qwen3Attention(nn.Module):
 
 
 class Qwen3DecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: EfficientDLMConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         if hasattr(config, 'attn_class'):
@@ -383,7 +377,7 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
 
 @auto_docstring
 class Qwen3PreTrainedModel(PreTrainedModel):
-    config_class = EfficientDLMConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Qwen3DecoderLayer"]
@@ -411,7 +405,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
 
 
 class Qwen3RotaryEmbedding(nn.Module):
-    def __init__(self, config: EfficientDLMConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -446,7 +440,7 @@ class Qwen3RotaryEmbedding(nn.Module):
 
 @auto_docstring
 class Qwen3Model(Qwen3PreTrainedModel):
-    def __init__(self, config: EfficientDLMConfig):
         super().__init__(config)
         self.config = config
 
@@ -696,7 +690,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
         dtype: torch.dtype,
         cache_position: torch.Tensor,
         batch_size: int,
-        config: EfficientDLMConfig,
         past_key_values: Cache,
     ):
         """
@@ -716,7 +710,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
             Indices depicting the position of the input sequence tokens in the sequence.
         batch_size (`torch.Tensor`):
             Batch size.
-        config (`EfficientDLMConfig`):
             The model's configuration class
         past_key_values (`Cache`):
             The cache class that is being used currently to generate
 
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from .configuration_nvrdiff import NVRDiffConfig
 
 
 if is_torch_flex_attn_available():
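This import block replaces the try/except fallback deleted on the left, so the file now hard-requires a transformers release that exports `TransformersKwargs`. For reference, the removed guard (reproduced from the deleted lines) looked like this and could be restored if older transformers versions must be supported:

```python
from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging

try:
    from transformers.utils import TransformersKwargs
except ImportError:
    # Stand-in for older transformers releases that predate TransformersKwargs.
    from typing import TypedDict

    class TransformersKwargs(TypedDict, total=False):
        pass
```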
 
 class Qwen3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    def __init__(self, config: NVRDiffConfig, layer_idx: int):
         super().__init__()
         self.config = config
 
 
 class Qwen3DecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: NVRDiffConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         if hasattr(config, 'attn_class'):
 
 
 @auto_docstring
 class Qwen3PreTrainedModel(PreTrainedModel):
+    config_class = NVRDiffConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Qwen3DecoderLayer"]
 
 
 class Qwen3RotaryEmbedding(nn.Module):
+    def __init__(self, config: NVRDiffConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 
 
 @auto_docstring
 class Qwen3Model(Qwen3PreTrainedModel):
+    def __init__(self, config: NVRDiffConfig):
         super().__init__(config)
         self.config = config
 
 
         dtype: torch.dtype,
         cache_position: torch.Tensor,
         batch_size: int,
+        config: NVRDiffConfig,
         past_key_values: Cache,
     ):
         """
 
             Indices depicting the position of the input sequence tokens in the sequence.
         batch_size (`torch.Tensor`):
             Batch size.
+        config (`NVRDiffConfig`):
             The model's configuration class
         past_key_values (`Cache`):
             The cache class that is being used currently to generate