smy111 committed on
Commit 9377294 · verified · 1 Parent(s): 83ed10f

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +1 -1
  2. config.json +1637 -2
  3. configuration_qwen3_moe.py +238 -0
  4. modeling_qwen3_moe.py +31 -10
README.md CHANGED
@@ -6,7 +6,7 @@
 - **Full Attention:** 15%
 - **Version:** 1.0
 
-<img src="./headwise.png" alt="screenshot" width="60%">
+<img src="./headwise.png" alt="screenshot">
 
  RTPurbo uses hybrid HeadWise Attention to compress the Qwen3Coder model. Specifically, it divides attention into two parts according to attention type:
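The head split referred to in the README is stored in the `headwise_config` entry added to `config.json` in this commit (see the diff below): each layer index maps to one 0/1 flag per attention head, and, per the modeling changes later in this commit, a 1 marks a head kept on full attention. A minimal sketch for checking the quoted ~15% full-attention ratio, assuming a local copy of this repo's `config.json`:

```python
import json

# Minimal sketch: count the full-attention (flag == 1) heads declared in
# headwise_config and compare against the README's ~15% figure.
with open("config.json") as f:
    cfg = json.load(f)

headwise = cfg["headwise_config"]  # {"layer_idx": [0/1 flag per attention head]}
total = sum(len(flags) for flags in headwise.values())
full = sum(sum(flags) for flags in headwise.values())
print(f"full-attention heads: {full}/{total} ({full / total:.1%})")
```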
 
config.json CHANGED
@@ -3,7 +3,8 @@
     "Qwen3MoeForCausalLM"
   ],
   "auto_map": {
-    "AutoModelForCausalLM": "modeling_qwen3_moe.Qwen3MoeForCausalLM"
+    "AutoModelForCausalLM": "modeling_qwen3_moe.Qwen3MoeForCausalLM",
+    "AutoConfig": "configuration_qwen3_moe.Qwen3MoeConfig"
   },
   "attention_bias": false,
   "attention_dropout": 0.0,
@@ -39,5 +40,1639 @@
   "use_cache": true,
   "use_qk_norm": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "headwise_config": {
+    "0":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "1":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "2":  [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0],
+    "3":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "4":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0],
+    "5":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "6":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,0,0],
+    "7":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "8":  [0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "9":  [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "10": [0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
+    "11": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0],
+    "12": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0],
+    "13": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,1,0,0,0,0,0,0,0,0],
+    "14": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0],
+    "15": [0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,1,1,1,0,0,0,0,0,0,0,0],
+    "16": [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "17": [0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0],
+    "18": [0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,1,1,0,0],
+    "19": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "20": [0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "21": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0],
+    "22": [1,0,1,0,1,0,1,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0],
+    "23": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0],
+    "24": [0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0],
+    "25": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,1,0,0,0,0,0,0,0,0],
+    "26": [0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,1,0,0],
+    "27": [0,0,0,0,1,1,0,1,0,0,0,0,0,0,1,0,1,1,1,0,1,0,1,1,0,0,0,0,0,0,0,1],
+    "28": [0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],
+    "29": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0,0,0,0],
+    "30": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,0,1,0,1,1,0,0],
+    "31": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],
+    "32": [0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],
+    "33": [0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],
+    "34": [1,0,1,1,1,0,1,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
+    "35": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "36": [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,1,1,1,0,0,1],
+    "37": [0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,0,0,0,0,0,0,0,0],
+    "38": [0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,1,1,1,0,1,0,0],
+    "39": [0,0,0,0,1,1,0,1,0,0,1,0,1,0,1,0,0,1,1,0,1,1,0,1,0,0,1,0,0,0,0,1],
+    "40": [0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0],
+    "41": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0],
+    "42": [0,0,0,0,0,0,0,1,0,0,1,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,0,0,1,0,0,0],
+    "43": [0,1,1,1,1,1,1,1,0,0,0,0,0,1,0,0,0,1,1,1,1,1,0,0,0,0,0,1,0,0,0,1],
+    "44": [0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],
+    "45": [0,0,0,1,0,0,1,1,1,1,1,1,1,0,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
+    "46": [0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
+    "47": [0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0]
+  }
 }
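The `auto_map` entries added above let `transformers` resolve the custom configuration and modeling classes shipped with this repo. A minimal loading sketch; the repo id is a placeholder, and `trust_remote_code=True` is required so the bundled `configuration_qwen3_moe.py`/`modeling_qwen3_moe.py` are used instead of the built-in Qwen3MoE classes:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/this-repo"  # placeholder: local path or Hub id of this checkpoint

# trust_remote_code=True makes transformers import the files referenced by auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    config=config,
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
```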
configuration_qwen3_moe.py ADDED
@@ -0,0 +1,238 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen3MoE model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+class Qwen3MoeConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
+    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of [Qwen/Qwen3-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Qwen3MoeModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 6144):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 768):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 128):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+            Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+
+    ```python
+    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig
+
+    >>> # Initializing a Qwen3MoE style configuration
+    >>> configuration = Qwen3MoeConfig()
+
+    >>> # Initializing a model from the Qwen3-15B-A2B style configuration
+    >>> model = Qwen3MoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_moe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen3Moe`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.experts.*.gate_proj": "colwise",
+        "layers.*.mlp.experts.*.up_proj": "colwise",
+        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=2048,
+        intermediate_size=6144,
+        num_hidden_layers=24,
+        num_attention_heads=32,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        use_sliding_window=False,
+        sliding_window=4096,
+        attention_dropout=0.0,
+        decoder_sparse_step=1,
+        moe_intermediate_size=768,
+        num_experts_per_tok=8,
+        num_experts=128,
+        norm_topk_prob=False,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        mlp_only_layers=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        # MoE arguments
+        self.decoder_sparse_step = decoder_sparse_step
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+        self.headwise_config = kwargs.get('headwise_config', None)
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+__all__ = ["Qwen3MoeConfig"]
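Since `headwise_config` is read from `**kwargs` rather than being a named parameter, it can be passed like any other extra config field and is also picked up automatically when the class is instantiated from the `config.json` above. A small sketch with made-up values for a toy 2-layer, 4-head case (illustrative only; the real checkpoint uses 48 layers with 32 heads each):

```python
from configuration_qwen3_moe import Qwen3MoeConfig

# Toy per-layer head flags for illustration; 1 marks a full-attention (retrieval) head.
toy_headwise = {"0": [0, 1, 0, 0], "1": [1, 0, 0, 1]}

config = Qwen3MoeConfig(
    num_hidden_layers=2,
    num_attention_heads=4,
    headwise_config=toy_headwise,  # stored via kwargs.get('headwise_config', None)
)
print(config.headwise_config["1"])  # [1, 0, 0, 1]
```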
modeling_qwen3_moe.py CHANGED
@@ -46,9 +46,11 @@ from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import OutputRecorder, check_model_inputs
 from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
 
+from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import flex_attention, create_block_mask
 flex_attention = torch.compile(flex_attention, dynamic=True)
 
+SEQ_THREAD = 16384
 SWA_TOKEN = 8192
 SINK_TOKEN = 4
 
@@ -107,12 +109,27 @@ def flex_attention_call(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    **kwargs: Unpack[TransformersKwargs],
 ):
-    S = query.shape[2]
-    block_mask = create_block_mask(sink_mask, 1, 1, S, S, device=query.device)
-    attn_output: torch.Tensor = flex_attention(query, key, value, block_mask=block_mask)
-
-    return attn_output
+    qlen = query.shape[2]
+    kvlen = key.shape[2]
+    output = torch.empty_like(query)
+    retrieval_heads = kwargs.pop('retrieval_heads')
+    non_retrieval_heads = kwargs.pop('non_retrieval_heads')
+    if retrieval_heads.sum():
+        attn_output: torch.Tensor = flash_attn_func(query.transpose(1, 2),
+                                                    key.transpose(1, 2),
+                                                    value.transpose(1, 2),
+                                                    causal=True).transpose(1, 2)
+        output[:, retrieval_heads, :, :] = attn_output[:, retrieval_heads, :, :]
+    if non_retrieval_heads.sum():
+        block_sink_mask = create_block_mask(sink_mask, 1, 1, qlen, kvlen, device=query.device)
+        attn_output: torch.Tensor = flex_attention(query,
+                                                   key,
+                                                   value,
+                                                   block_mask=block_sink_mask)
+        output[:, non_retrieval_heads, :, :] = attn_output[:, non_retrieval_heads, :, :]
+    return output
 
 def flex_attention_forward(
     module: nn.Module,
@@ -124,7 +141,6 @@ def flex_attention_forward(
     dropout: float = 0.0,
     **kwargs: Unpack[TransformersKwargs],
 ):
-
     seq_len, q_head_num = query.shape[2], query.shape[1]
     kv_head_num = key.shape[1]
 
@@ -132,7 +148,7 @@
     key = repeat_kv(key, n_repeat)
     value = repeat_kv(value, n_repeat)
 
-    attn_output = flex_attention_call(query, key, value)
+    attn_output = flex_attention_call(query, key, value, **kwargs)
 
     # return attn_output, None
     return attn_output.transpose(1, 2), None
@@ -191,6 +207,10 @@ class Qwen3MoeAttention(nn.Module):
         self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = getattr(config, "sliding_window", None)
+        mask_list = config.headwise_config[str(layer_idx)]
+        mask_tensor = torch.tensor(mask_list)
+        self.non_retrieval_heads = (mask_tensor == 0)
+        self.retrieval_heads = (mask_tensor == 1)
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -217,11 +237,12 @@
         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        import pdb; pdb.set_trace()
-
         attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation == "headwise":
+        # if self.config._attn_implementation == "headwise":
+        if query_states.size(2) > SEQ_THREAD:
             attention_interface = flex_attention_forward
+            kwargs['non_retrieval_heads'] = self.non_retrieval_heads
+            kwargs['retrieval_heads'] = self.retrieval_heads
         elif self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
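The routing in `flex_attention_call` above relies on boolean indexing along the head dimension: an output for every head is produced by each branch that runs, and the final tensor takes each head's slice from whichever branch that head belongs to. A small self-contained sketch of that merge step; the shapes and head split are illustrative, not the checkpoint's, and random tensors stand in for the two attention outputs:

```python
import torch

batch, heads, seq, dim = 1, 4, 8, 16
mask = torch.tensor([1, 0, 0, 1])      # per-head flags, as in headwise_config
retrieval_heads = mask == 1            # heads kept on full attention
non_retrieval_heads = mask == 0        # heads routed to the sink/sliding-window path

full_out = torch.randn(batch, heads, seq, dim)    # stand-in for the flash_attn_func result
window_out = torch.randn(batch, heads, seq, dim)  # stand-in for the flex_attention result

# Merge: each head's output comes from the branch its flag selects.
out = torch.empty_like(full_out)
out[:, retrieval_heads] = full_out[:, retrieval_heads]
out[:, non_retrieval_heads] = window_out[:, non_retrieval_heads]
```

Note that the head-wise path only replaces the configured attention implementation when the query length exceeds `SEQ_THREAD` (16384 tokens); shorter sequences fall through to the standard `_attn_implementation` branch.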