YongganFu commited on
Commit
dfcb979
·
verified ·
1 Parent(s): 418d9e4

Upload model

Browse files
config.json CHANGED
@@ -22,7 +22,6 @@
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
25
- "dtype": "float32",
26
  "enforce_mask": false,
27
  "eos_token_id": 2,
28
  "global_loss_avg": false,
@@ -57,12 +56,25 @@
57
  "rope_type": "yarn",
58
  "type": "yarn"
59
  },
 
 
 
 
 
 
 
 
 
 
 
 
60
  "rope_theta": 1000000.0,
61
  "seq_length": 8192,
62
  "sliding_window": null,
63
  "tie_word_embeddings": false,
64
  "tok_mask_half_life_ratio": null,
65
- "transformers_version": "5.0.0rc1",
 
66
  "use_cache": false,
67
  "vocab_size": 131072
68
  }
 
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
 
25
  "enforce_mask": false,
26
  "eos_token_id": 2,
27
  "global_loss_avg": false,
 
56
  "rope_type": "yarn",
57
  "type": "yarn"
58
  },
59
+ "rope_scaling": {
60
+ "beta_fast": 32.0,
61
+ "beta_slow": 1.0,
62
+ "factor": 16.0,
63
+ "llama_4_scaling_beta": 0.1,
64
+ "mscale": 1.0,
65
+ "mscale_all_dim": 1.0,
66
+ "original_max_position_embeddings": 16384,
67
+ "rope_theta": 1000000.0,
68
+ "rope_type": "yarn",
69
+ "type": "yarn"
70
+ },
71
  "rope_theta": 1000000.0,
72
  "seq_length": 8192,
73
  "sliding_window": null,
74
  "tie_word_embeddings": false,
75
  "tok_mask_half_life_ratio": null,
76
+ "torch_dtype": "bfloat16",
77
+ "transformers_version": "4.55.4",
78
  "use_cache": false,
79
  "vocab_size": 131072
80
  }
configuration_ministral_dlm.py CHANGED
@@ -155,6 +155,7 @@ class MinistralDLMConfig(PretrainedConfig):
155
  tie_word_embeddings=False,
156
  rope_theta=1000000.0,
157
  rope_parameters=None,
 
158
  attention_bias=False,
159
  attention_dropout=0.0,
160
  mlp_bias=False,
@@ -204,6 +205,7 @@ class MinistralDLMConfig(PretrainedConfig):
204
  self.use_cache = use_cache
205
  self.rope_theta = rope_theta
206
  self.rope_parameters = rope_parameters
 
207
  self.attention_bias = attention_bias
208
  self.attention_dropout = attention_dropout
209
  self.mlp_bias = mlp_bias
 
155
  tie_word_embeddings=False,
156
  rope_theta=1000000.0,
157
  rope_parameters=None,
158
+ rope_scaling=None,
159
  attention_bias=False,
160
  attention_dropout=0.0,
161
  mlp_bias=False,
 
205
  self.use_cache = use_cache
206
  self.rope_theta = rope_theta
207
  self.rope_parameters = rope_parameters
208
+ self.rope_scaling = rope_scaling
209
  self.attention_bias = attention_bias
210
  self.attention_dropout = attention_dropout
211
  self.mlp_bias = mlp_bias
generation_config.json CHANGED
@@ -2,6 +2,6 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
- "transformers_version": "5.0.0rc1",
6
  "use_cache": false
7
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
+ "transformers_version": "4.55.4",
6
  "use_cache": false
7
  }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59f86ca799393141bddf2d14132ac34374fea32b3cf87bf4bf1ed1839be30d1f
3
+ size 4999767504
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbfe58b42e25e9b11b3c1f34d59de5ae84133285bc2c3dc24ffc9d2280bf1c1b
3
+ size 4999802928
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ade91c2da2eb811769316e84a2bd8034f2c48323e4249cf0ebc5f58a4ca703cf
3
+ size 4915916376
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dbf0f5cba01173593d5a4cd658bf64bd37eda9e33b418493576eee072013636
3
+ size 2063657632
model.safetensors.index.json ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 8489553920,
4
+ "total_size": 16979107840
5
+ },
6
+ "weight_map": {
7
+ "diffusion_head.weight": "model-00004-of-00004.safetensors",
8
+ "encoder.embed_tokens.weight": "model-00001-of-00004.safetensors",
9
+ "encoder.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
10
+ "encoder.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
11
+ "encoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
12
+ "encoder.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
13
+ "encoder.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
14
+ "encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "encoder.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
17
+ "encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
18
+ "encoder.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "encoder.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "encoder.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
21
+ "encoder.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
22
+ "encoder.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
23
+ "encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
24
+ "encoder.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
25
+ "encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
26
+ "encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
27
+ "encoder.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "encoder.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
29
+ "encoder.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
30
+ "encoder.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
31
+ "encoder.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
32
+ "encoder.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
33
+ "encoder.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
34
+ "encoder.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
35
+ "encoder.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
36
+ "encoder.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "encoder.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "encoder.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
39
+ "encoder.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
40
+ "encoder.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
41
+ "encoder.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
42
+ "encoder.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
43
+ "encoder.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
44
+ "encoder.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
45
+ "encoder.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "encoder.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
47
+ "encoder.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
48
+ "encoder.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
49
+ "encoder.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
50
+ "encoder.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "encoder.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "encoder.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
53
+ "encoder.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
54
+ "encoder.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
55
+ "encoder.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
56
+ "encoder.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
57
+ "encoder.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
58
+ "encoder.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
59
+ "encoder.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
60
+ "encoder.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
61
+ "encoder.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
62
+ "encoder.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
63
+ "encoder.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
64
+ "encoder.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
65
+ "encoder.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
66
+ "encoder.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
67
+ "encoder.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
68
+ "encoder.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
69
+ "encoder.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
70
+ "encoder.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
71
+ "encoder.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
72
+ "encoder.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "encoder.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
74
+ "encoder.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
75
+ "encoder.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
76
+ "encoder.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
77
+ "encoder.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
78
+ "encoder.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
79
+ "encoder.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
80
+ "encoder.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
81
+ "encoder.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
82
+ "encoder.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
83
+ "encoder.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
84
+ "encoder.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
85
+ "encoder.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
86
+ "encoder.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
87
+ "encoder.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
88
+ "encoder.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "encoder.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
90
+ "encoder.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
91
+ "encoder.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
92
+ "encoder.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
93
+ "encoder.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
94
+ "encoder.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
95
+ "encoder.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
96
+ "encoder.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
97
+ "encoder.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
98
+ "encoder.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
99
+ "encoder.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
100
+ "encoder.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
101
+ "encoder.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
102
+ "encoder.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
103
+ "encoder.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
104
+ "encoder.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
105
+ "encoder.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
106
+ "encoder.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
107
+ "encoder.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
108
+ "encoder.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
109
+ "encoder.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
110
+ "encoder.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
111
+ "encoder.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
112
+ "encoder.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
113
+ "encoder.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
114
+ "encoder.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
115
+ "encoder.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
116
+ "encoder.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
117
+ "encoder.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
118
+ "encoder.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
119
+ "encoder.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
120
+ "encoder.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
121
+ "encoder.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
122
+ "encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
123
+ "encoder.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
124
+ "encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
125
+ "encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
126
+ "encoder.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
127
+ "encoder.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
128
+ "encoder.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
129
+ "encoder.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
130
+ "encoder.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
131
+ "encoder.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
132
+ "encoder.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
133
+ "encoder.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
134
+ "encoder.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
135
+ "encoder.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
136
+ "encoder.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
137
+ "encoder.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
138
+ "encoder.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
139
+ "encoder.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
140
+ "encoder.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
141
+ "encoder.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
142
+ "encoder.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
143
+ "encoder.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
144
+ "encoder.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
145
+ "encoder.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
146
+ "encoder.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
147
+ "encoder.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
148
+ "encoder.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
149
+ "encoder.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
150
+ "encoder.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
151
+ "encoder.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
152
+ "encoder.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
153
+ "encoder.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
154
+ "encoder.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
155
+ "encoder.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
156
+ "encoder.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
157
+ "encoder.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
158
+ "encoder.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
159
+ "encoder.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
160
+ "encoder.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
161
+ "encoder.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
162
+ "encoder.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
163
+ "encoder.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
164
+ "encoder.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
165
+ "encoder.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
166
+ "encoder.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
167
+ "encoder.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
168
+ "encoder.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
169
+ "encoder.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
170
+ "encoder.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
171
+ "encoder.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
172
+ "encoder.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
173
+ "encoder.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
174
+ "encoder.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
175
+ "encoder.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
176
+ "encoder.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
177
+ "encoder.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
178
+ "encoder.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
179
+ "encoder.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
180
+ "encoder.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
181
+ "encoder.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
182
+ "encoder.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
183
+ "encoder.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
184
+ "encoder.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
185
+ "encoder.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
186
+ "encoder.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
187
+ "encoder.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
188
+ "encoder.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
189
+ "encoder.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
190
+ "encoder.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
191
+ "encoder.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
192
+ "encoder.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
193
+ "encoder.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
194
+ "encoder.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
195
+ "encoder.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
196
+ "encoder.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
197
+ "encoder.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
198
+ "encoder.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
199
+ "encoder.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
200
+ "encoder.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
201
+ "encoder.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
202
+ "encoder.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
203
+ "encoder.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
204
+ "encoder.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
205
+ "encoder.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
206
+ "encoder.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
207
+ "encoder.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
208
+ "encoder.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
209
+ "encoder.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
210
+ "encoder.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
211
+ "encoder.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
212
+ "encoder.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
213
+ "encoder.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
214
+ "encoder.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
215
+ "encoder.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
216
+ "encoder.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
217
+ "encoder.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
218
+ "encoder.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
219
+ "encoder.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
220
+ "encoder.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
221
+ "encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
222
+ "encoder.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
223
+ "encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
224
+ "encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
225
+ "encoder.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
226
+ "encoder.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
227
+ "encoder.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
228
+ "encoder.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
229
+ "encoder.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
230
+ "encoder.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
231
+ "encoder.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
232
+ "encoder.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
233
+ "encoder.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
234
+ "encoder.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
235
+ "encoder.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
236
+ "encoder.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
237
+ "encoder.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
238
+ "encoder.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
239
+ "encoder.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
240
+ "encoder.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
241
+ "encoder.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
242
+ "encoder.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
243
+ "encoder.layers.32.input_layernorm.weight": "model-00004-of-00004.safetensors",
244
+ "encoder.layers.32.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
245
+ "encoder.layers.32.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
246
+ "encoder.layers.32.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
247
+ "encoder.layers.32.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
248
+ "encoder.layers.32.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
249
+ "encoder.layers.32.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
250
+ "encoder.layers.32.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
251
+ "encoder.layers.32.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
252
+ "encoder.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
253
+ "encoder.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
254
+ "encoder.layers.33.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
255
+ "encoder.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
256
+ "encoder.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
257
+ "encoder.layers.33.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
258
+ "encoder.layers.33.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
259
+ "encoder.layers.33.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
260
+ "encoder.layers.33.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
261
+ "encoder.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
262
+ "encoder.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
263
+ "encoder.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
264
+ "encoder.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
265
+ "encoder.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
266
+ "encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
267
+ "encoder.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
268
+ "encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
269
+ "encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
270
+ "encoder.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
271
+ "encoder.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
272
+ "encoder.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
273
+ "encoder.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
274
+ "encoder.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
275
+ "encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
276
+ "encoder.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
277
+ "encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
278
+ "encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
279
+ "encoder.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
280
+ "encoder.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
281
+ "encoder.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
282
+ "encoder.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
283
+ "encoder.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
284
+ "encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
285
+ "encoder.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
286
+ "encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
287
+ "encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
288
+ "encoder.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
289
+ "encoder.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
290
+ "encoder.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
291
+ "encoder.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
292
+ "encoder.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
293
+ "encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
294
+ "encoder.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
295
+ "encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
296
+ "encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
297
+ "encoder.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
298
+ "encoder.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
299
+ "encoder.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
300
+ "encoder.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
301
+ "encoder.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
302
+ "encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
303
+ "encoder.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
304
+ "encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
305
+ "encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
306
+ "encoder.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
307
+ "encoder.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
308
+ "encoder.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
309
+ "encoder.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
310
+ "encoder.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
311
+ "encoder.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
312
+ "encoder.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
313
+ "encoder.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
314
+ "encoder.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
315
+ "encoder.norm.weight": "model-00004-of-00004.safetensors"
316
+ }
317
+ }
modeling_ministral.py CHANGED
@@ -9,7 +9,8 @@ from transformers.utils.generic import check_model_inputs
9
  from transformers.activations import ACT2FN
10
  from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
- from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 
13
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
14
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
15
  from transformers.modeling_layers import (
@@ -23,7 +24,7 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
23
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
24
  from transformers.processing_utils import Unpack
25
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
26
- from transformers.utils.generic import maybe_autocast
27
  from .configuration_ministral_dlm import MinistralDLMConfig
28
 
29
 
@@ -33,8 +34,7 @@ def rotate_half(x):
33
  x2 = x[..., x.shape[-1] // 2 :]
34
  return torch.cat((-x2, x1), dim=-1)
35
 
36
-
37
- @use_kernel_func_from_hub("rotary_pos_emb")
38
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
39
  """Applies Rotary Position Embedding to the query and key tensors.
40
 
@@ -105,7 +105,7 @@ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_positi
105
  return scaling.unsqueeze(-1)
106
 
107
 
108
- @use_kernelized_func(apply_rotary_pos_emb)
109
  class Ministral3Attention(nn.Module):
110
  """Multi-headed attention from 'Attention Is All You Need' paper"""
111
 
@@ -356,12 +356,13 @@ class Ministral3RotaryEmbedding(nn.Module):
356
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
357
  position_ids_expanded = position_ids[:, None, :].float()
358
 
359
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
360
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
361
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
362
- emb = torch.cat((freqs, freqs), dim=-1)
363
- cos = emb.cos() * self.attention_scaling
364
- sin = emb.sin() * self.attention_scaling
 
365
 
366
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
367
 
@@ -404,7 +405,8 @@ class Ministral3Model(Ministral3PreTrainedModel):
404
  inputs_embeds = self.embed_tokens(input_ids)
405
 
406
  if use_cache and past_key_values is None:
407
- past_key_values = DynamicCache(config=self.config)
 
408
 
409
  if cache_position is None:
410
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
9
  from transformers.activations import ACT2FN
10
  from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
+ # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
+ from transformers.integrations import use_kernel_forward_from_hub
14
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
 
24
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
  from transformers.processing_utils import Unpack
26
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
+ # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
 
 
34
  x2 = x[..., x.shape[-1] // 2 :]
35
  return torch.cat((-x2, x1), dim=-1)
36
 
37
+ # @use_kernel_func_from_hub("rotary_pos_emb")
 
38
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
39
  """Applies Rotary Position Embedding to the query and key tensors.
40
 
 
105
  return scaling.unsqueeze(-1)
106
 
107
 
108
+ # @use_kernelized_func(apply_rotary_pos_emb)
109
  class Ministral3Attention(nn.Module):
110
  """Multi-headed attention from 'Attention Is All You Need' paper"""
111
 
 
356
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
357
  position_ids_expanded = position_ids[:, None, :].float()
358
 
359
+ # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
360
+ # with maybe_autocast(device_type=device_type, enabled=False): # Force float32
361
+
362
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
363
+ emb = torch.cat((freqs, freqs), dim=-1)
364
+ cos = emb.cos() * self.attention_scaling
365
+ sin = emb.sin() * self.attention_scaling
366
 
367
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
368
 
 
405
  inputs_embeds = self.embed_tokens(input_ids)
406
 
407
  if use_cache and past_key_values is None:
408
+ # past_key_values = DynamicCache(config=self.config)
409
+ past_key_values = DynamicCache()
410
 
411
  if cache_position is None:
412
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
modeling_ministral_dlm.py CHANGED
@@ -1,5 +1,4 @@
1
  import copy
2
- from dataclasses import dataclass
3
  from typing import Callable, Optional, Tuple, Union
4
  import random
5
  import os
@@ -11,7 +10,6 @@ import torch
11
  import torch.nn.functional as F
12
  from torch import nn
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
- from transformers.utils import ModelOutput
15
 
16
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
 
@@ -31,17 +29,6 @@ from .chat_utils import generate_with_prefix_cache_block_diff
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
-
35
- @dataclass
36
- class MinistralDiffOutputWithPast(ModelOutput):
37
- loss: torch.FloatTensor | None = None
38
- logits: torch.FloatTensor | None = None
39
- causal_logits: torch.FloatTensor | None = None
40
- past_key_values: Cache | None = None
41
- hidden_states: tuple[torch.FloatTensor, ...] | None = None
42
- attentions: tuple[torch.FloatTensor, ...] | None = None
43
-
44
-
45
  # @torch.compile(dynamic=True, mode="reduce-overhead")
46
  # @torch.compile(mode="default")
47
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
@@ -492,7 +479,6 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
492
  loss_mask: Optional[torch.Tensor] = None,
493
  ce_loss_weight: float = 1.0,
494
  output_last_hidden_states_only: bool = False,
495
- skip_loss: bool = False,
496
  **kwargs,
497
  ) -> CausalLMOutputWithPast:
498
 
@@ -579,7 +565,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
579
  logits = logits[:, :input_ids_len]
580
 
581
  loss = None
582
- if labels is not None and not skip_loss:
583
  if self.config.dlm_paradigm == 'autoregressive':
584
  shift_logits = logits[..., :-1, :].contiguous()
585
  shift_labels = labels[..., 1:].contiguous()
@@ -716,10 +702,9 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
716
  else:
717
  loss = (loss, num_mask_tokens)
718
 
719
- return MinistralDiffOutputWithPast(
720
  loss=loss if not is_teacher else logits,
721
  logits=logits,
722
- causal_logits=causal_logits,
723
  past_key_values=enc_out.past_key_values,
724
  hidden_states=None,
725
  attentions=None,
 
1
  import copy
 
2
  from typing import Callable, Optional, Tuple, Union
3
  import random
4
  import os
 
10
  import torch.nn.functional as F
11
  from torch import nn
12
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
 
13
 
14
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
15
 
 
29
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
30
  from .configuration_ministral_dlm import MinistralDLMConfig
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  # @torch.compile(dynamic=True, mode="reduce-overhead")
33
  # @torch.compile(mode="default")
34
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
 
479
  loss_mask: Optional[torch.Tensor] = None,
480
  ce_loss_weight: float = 1.0,
481
  output_last_hidden_states_only: bool = False,
 
482
  **kwargs,
483
  ) -> CausalLMOutputWithPast:
484
 
 
565
  logits = logits[:, :input_ids_len]
566
 
567
  loss = None
568
+ if labels is not None:
569
  if self.config.dlm_paradigm == 'autoregressive':
570
  shift_logits = logits[..., :-1, :].contiguous()
571
  shift_labels = labels[..., 1:].contiguous()
 
702
  else:
703
  loss = (loss, num_mask_tokens)
704
 
705
+ return CausalLMOutputWithPast(
706
  loss=loss if not is_teacher else logits,
707
  logits=logits,
 
708
  past_key_values=enc_out.past_key_values,
709
  hidden_states=None,
710
  attentions=None,