File size: 65,141 Bytes
19f7733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn

from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    ModelOutput,
)
from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.masking_utils import (
    create_causal_mask,
    create_sliding_window_causal_mask,
)
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
from transformers.cache_utils import Cache, DynamicCache
from transformers.integrations import use_kernel_forward_from_hub


try:
    from .configuration_trinity_vlm import AfmoeConfig, TrinityVLMConfig
except Exception:
    from configuration_trinity_vlm import AfmoeConfig, TrinityVLMConfig


def _compute_default_rope_parameters(
    config=None,
    device: torch.device | None = None,
    seq_len: int | None = None,
    layer_type: str | None = None,
) -> tuple[torch.Tensor, float]:
    del seq_len, layer_type
    if config is None:
        raise ValueError("config is required to compute default RoPE parameters.")

    base = getattr(config, "rope_theta", 10000.0)
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)
    inv_freq = 1.0 / (
        base
        ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
    )
    return inv_freq, 1.0


# The installed transformers version may not provide a "default" RoPE initialiser;
# register the local fallback only when the key is missing so an upstream entry wins.
ROPE_INIT_FUNCTIONS.setdefault("default", _compute_default_rope_parameters)

class AfmoeRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) producing ``cos``/``sin`` tables.

    The inverse-frequency initialiser is looked up in ``ROPE_INIT_FUNCTIONS`` by the
    rope type declared in ``config.rope_scaling`` (key ``"rope_type"``, or the legacy
    key ``"type"``). When ``rope_scaling`` is absent or None the ``"default"``
    initialiser is used. Rope types whose name contains ``"dynamic"`` recompute the
    inverse frequencies in :meth:`forward` based on the requested positions.
    """

    def __init__(self, config: AfmoeConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            # BC: older configs spell the key "type" instead of "rope_type".
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        self.max_seq_len_cached = self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        initial_inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", initial_inv_freq, persistent=False)
        # Handle on the unscaled frequencies so dynamic RoPE can restore them later.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute or restore ``inv_freq`` for dynamic RoPE variants.

        1. If the requested length exceeds the cached length, the frequencies are
           recomputed for the new length (allowing scaling).
        2. If the requested length is back within the original range after having
           grown, the original frequencies are restored (avoids losing precision on
           short sequences).
        """
        needed_len = torch.max(position_ids) + 1

        if needed_len > self.max_seq_len_cached:  # growth
            grown_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=needed_len
            )
            # TODO joao: may break with compilation
            self.register_buffer("inv_freq", grown_inv_freq, persistent=False)
            self.max_seq_len_cached = needed_len

        back_in_range = needed_len < self.original_max_seq_len
        was_grown = self.max_seq_len_cached > self.original_max_seq_len
        if back_in_range and was_grown:  # reset
            # The registered buffer follows module.to(), but this plain attribute does
            # not, so move it explicitly before re-registering it.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    def compute_default_rope_parameters(
        self,
        config=None,
        device: torch.device | None = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple[torch.Tensor, float]:
        """Method wrapper around ``_compute_default_rope_parameters``.

        Uses ``self.config`` when ``config`` is not given (or is falsy).
        """
        effective_config = config if config else self.config
        return _compute_default_rope_parameters(
            config=effective_config,
            device=device,
            seq_len=seq_len,
            layer_type=layer_type,
        )

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``x`` is only used for its device and dtype. ``position_ids`` is indexed as
        ``(batch, seq)``; both outputs have shape ``(batch, seq, 2 * len(inv_freq))``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        batch = position_ids.shape[0]
        inv_freq_b = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions_b = position_ids[:, None, :].float()

        # Autocast is switched off so the angles are computed in float32
        # (see https://github.com/huggingface/transformers/pull/29285).
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (inv_freq_b.float().to(x.device) @ positions_b.float()).transpose(1, 2)
            angles = torch.cat((angles, angles), dim=-1)
            cos = angles.cos()
            sin = angles.sin()

        # Scaled rope variants (e.g. yarn) fold an attention factor into cos/sin.
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Split the last axis at ``size // 2`` into ``(a, b)`` and return ``cat(-b, a)``.

    For an odd last-axis size the second part ``b`` is one element longer.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` by the rotary position embedding given by ``cos``/``sin``.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table, e.g. as returned by ``AfmoeRotaryEmbedding.forward``
            with shape ``(batch, seq, head_dim)``.
        sin: sine table with the same shape as ``cos``.
        position_ids: unused; kept only for backward compatibility.
        unsqueeze_dim: axis inserted into ``cos``/``sin`` so they broadcast against
            ``q``/``k``. Use 1 when ``q``/``k`` are ``(batch, heads, seq, head_dim)``
            and 2 when they are ``(batch, seq, heads, head_dim)``.

    Returns:
        ``(q_embed, k_embed)``: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_embed = _rotate(q)
    k_embed = _rotate(k)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. The input is
    returned unchanged (same object) when ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Broadcast a new repeat axis (no copy), then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

@use_kernel_forward_from_hub("RMSNorm")
class AfmoeRMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale (equivalent to T5LayerNorm).

    The normalisation is computed in float32 and cast back to the input dtype
    before the weight is applied. The hub decorator may substitute a kernel
    implementation of ``forward`` (behaviour defined outside this file).
    """

    def __init__(self, hidden_size: int, eps: float):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"



def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention.

    ``key``/``value`` are 4-D with ``num_key_value_heads`` heads, as required by
    ``repeat_kv``. They are expanded by ``module.num_key_value_groups`` to match the
    query heads. An additive ``attention_mask`` is trimmed to the key length. Softmax
    runs in float32 and is cast back to ``query.dtype``. Dropout is active only
    when ``module.training`` is true.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` has dims 1 and 2 swapped
        relative to the matmul result and is made contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs


class AfmoeMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    ``intermediate_size`` overrides ``config.intermediate_size`` when given (and
    truthy). All projections are bias-free.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size or config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


class AfmoeTokenChoiceRouter(nn.Module):
    """Token-choice top-K router for MoE routing.

    Each token independently selects its ``num_experts_per_tok`` highest-scoring
    experts. Scores come from a bias-free linear ``gate`` followed by a sigmoid
    (``score_func == "sigmoid"``) or a softmax, both computed in float32.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.score_func = config.score_func
        self.route_norm = config.route_norm
        self.route_scale = config.route_scale
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states, expert_bias: torch.Tensor | None):
        """Pick experts for every token.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``; all
                leading dimensions are flattened into one token axis. Non-contiguous
                inputs are accepted.
            expert_bias: optional per-expert additive bias. It only influences
                *which* experts are chosen; the returned weights are the unbiased
                scores of those experts.

        Returns:
            ``(top_scores, selected_experts)``, both shaped ``(num_tokens, top_k)``.
            For sigmoid scoring with ``route_norm`` the scores are renormalised to
            sum to 1 per token; all scores are then multiplied by ``route_scale``.
        """
        hidden_dim = hidden_states.shape[-1]
        # reshape (not view): a free view for contiguous inputs, but it also accepts
        # non-contiguous ones, where view() raises.
        hidden_states = hidden_states.reshape(-1, hidden_dim)

        scores = self.gate(hidden_states)

        # Apply scoring function in float32 for stability
        if self.score_func == "sigmoid":
            scores = torch.sigmoid(scores.to(torch.float32))
        else:
            scores = F.softmax(scores.to(torch.float32), dim=-1)

        if expert_bias is not None:
            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
            top_scores = scores.gather(dim=1, index=selected_experts)
        else:
            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)

        # Sigmoid scores do not sum to 1, so optionally renormalise over the chosen experts.
        if self.score_func == "sigmoid" and self.route_norm:
            denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
            top_scores = top_scores / denominator

        top_scores = top_scores * self.route_scale
        return top_scores, selected_experts


def _can_use_grouped_mm(hidden_states: torch.Tensor) -> bool:
    return (
        hidden_states.is_cuda
        and hidden_states.dtype == torch.bfloat16
        and hasattr(F, "grouped_mm")
    )


def _router_forward(
    router: nn.Module,
    hidden_states: torch.Tensor,
    expert_bias: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    _, _, hidden_dim = hidden_states.shape
    hidden_states_flat = hidden_states.view(-1, hidden_dim)
    router_logits = router.gate(hidden_states_flat)

    if router.score_func == "sigmoid":
        router_probs = torch.sigmoid(router_logits.to(torch.float32))
    else:
        router_probs = F.softmax(router_logits.to(torch.float32), dim=-1)

    if expert_bias is not None:
        _, selected_experts = torch.topk(router_probs + expert_bias, k=router.top_k, dim=1)
        top_scores = router_probs.gather(dim=1, index=selected_experts)
    else:
        top_scores, selected_experts = torch.topk(router_probs, k=router.top_k, dim=1)

    if router.score_func == "sigmoid" and router.route_norm:
        denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
        top_scores = top_scores / denominator

    top_scores = top_scores * router.route_scale
    return hidden_states_flat, router_logits, router_probs, top_scores, selected_experts


def _router_aux_loss(
    router_probs: torch.Tensor,
    selected_experts: torch.Tensor,
    *,
    num_experts: int,
) -> torch.Tensor:
    selected_flat = selected_experts.reshape(-1)
    top_k = max(1, selected_experts.shape[-1])
    token_count = max(1, selected_experts.shape[0])
    tokens_per_expert = torch.bincount(
        selected_flat,
        minlength=num_experts,
    ).to(torch.float32) / float(token_count * top_k)
    router_prob_per_expert = router_probs.mean(dim=0)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)


def _get_grouped_projection_weights(
    moe_layer: nn.Module,
    expert_ids: torch.Tensor,
    *,
    projection_name: str,
) -> torch.Tensor:
    if expert_ids.numel() == 0:
        raise ValueError("Cannot select grouped weights for an empty expert set.")

    packed_weights = getattr(moe_layer, f"packed_{projection_name}", None)
    if isinstance(packed_weights, nn.Parameter):
        if expert_ids.numel() == packed_weights.shape[0]:
            full_expert_ids = torch.arange(
                packed_weights.shape[0],
                device=expert_ids.device,
                dtype=expert_ids.dtype,
            )
            if torch.equal(expert_ids, full_expert_ids):
                return packed_weights
        return packed_weights.index_select(0, expert_ids).contiguous()

    experts = getattr(moe_layer, "experts", None)
    if experts is None:
        raise ValueError(f"Layer has neither packed experts nor per-expert modules for {projection_name}.")

    projection_weights = []
    requires_grad = False
    for expert_id in expert_ids.tolist():
        weight = getattr(experts[expert_id], projection_name).weight
        requires_grad = requires_grad or weight.requires_grad
        projection_weights.append(weight.transpose(0, 1))

    if not requires_grad:
        projection_weights = [weight.detach() for weight in projection_weights]

    return torch.stack(projection_weights, dim=0).contiguous()


def _dense_packed_moe_forward(
    moe_layer: nn.Module,
    routed_input: torch.Tensor,
    token_to_expert: torch.Tensor,
) -> torch.Tensor:
    routed_output = torch.zeros(
        routed_input.shape[0],
        moe_layer.config.hidden_size,
        device=routed_input.device,
        dtype=routed_input.dtype,
    )

    packed_gate_proj = getattr(moe_layer, "packed_gate_proj", None)
    packed_up_proj = getattr(moe_layer, "packed_up_proj", None)
    packed_down_proj = getattr(moe_layer, "packed_down_proj", None)
    if not all(
        isinstance(weight, nn.Parameter)
        for weight in (packed_gate_proj, packed_up_proj, packed_down_proj)
    ):
        for expert_id in range(moe_layer.config.num_experts):
            mask = token_to_expert == expert_id
            if not mask.any():
                continue
            expert_input = routed_input[mask]
            expert_out = moe_layer.experts[expert_id](expert_input)
            routed_output[mask] = expert_out
        return routed_output

    act_fn = ACT2FN[moe_layer.config.hidden_act]
    for expert_id in range(moe_layer.config.num_experts):
        mask = token_to_expert == expert_id
        if not mask.any():
            continue
        expert_input = routed_input[mask]
        gate_proj = F.linear(expert_input, packed_gate_proj[expert_id].transpose(0, 1))
        up_proj = F.linear(expert_input, packed_up_proj[expert_id].transpose(0, 1))
        activated = act_fn(gate_proj) * up_proj
        expert_out = F.linear(activated, packed_down_proj[expert_id].transpose(0, 1))
        routed_output[mask] = expert_out
    return routed_output


def _accumulate_routed_output(
    shared_output: torch.Tensor,
    routed_output: torch.Tensor,
    top_scores_sorted: torch.Tensor,
    token_indices_sorted: torch.Tensor,
) -> torch.Tensor:
    output = shared_output.to(torch.float32)
    if routed_output.numel() == 0:
        return output

    hidden_dim = routed_output.shape[-1]
    bytes_per_row = max(1, hidden_dim * 4)
    target_chunk_bytes = 16 * 1024 * 1024
    rows_per_chunk = max(1, target_chunk_bytes // bytes_per_row)

    for start in range(0, routed_output.shape[0], rows_per_chunk):
        end = min(start + rows_per_chunk, routed_output.shape[0])
        weighted_chunk = routed_output[start:end].to(torch.float32)
        weighted_chunk.mul_(top_scores_sorted[start:end].unsqueeze(-1))
        output.index_add_(0, token_indices_sorted[start:end], weighted_chunk)

    return output

class AfmoeMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with optional shared experts.

    Tokens are routed with ``_router_forward`` (using ``self.router`` and the
    selection-only ``self.expert_bias``) to ``config.num_experts_per_tok`` experts.
    Expert weights live either in three packed ``nn.Parameter`` tensors
    (``config.packed_experts``) or in a ``ModuleList`` of ``AfmoeMLP`` experts.
    The score-weighted routed outputs are added to the shared-expert output, or
    to zeros when ``config.num_shared_experts == 0``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_packed_experts = bool(getattr(config, "packed_experts", False))
        self.router = AfmoeTokenChoiceRouter(config)

        self.shared_experts = None
        if config.num_shared_experts > 0:
            # All shared experts are fused into a single MLP whose width is their sum.
            self.shared_experts = AfmoeMLP(
                config, config.moe_intermediate_size * config.num_shared_experts
            )
        # Per-expert bias added to router probabilities before top-k (selection only).
        # requires_grad=False: not updated by gradients; any update happens outside
        # this file (not visible here).
        self.expert_bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32), requires_grad=False)
        if self.use_packed_experts:
            self.experts = None
            # Packed layouts are (num_experts, in_features, out_features), i.e. the
            # transpose of nn.Linear.weight, so forward() can pass them straight to
            # F.grouped_mm.
            self.packed_gate_proj = nn.Parameter(
                torch.empty(config.num_experts, config.hidden_size, config.moe_intermediate_size)
            )
            self.packed_up_proj = nn.Parameter(
                torch.empty(config.num_experts, config.hidden_size, config.moe_intermediate_size)
            )
            self.packed_down_proj = nn.Parameter(
                torch.empty(config.num_experts, config.moe_intermediate_size, config.hidden_size)
            )
            self.reset_parameters()
        else:
            self.experts = nn.ModuleList(
                [
                    AfmoeMLP(config, intermediate_size=config.moe_intermediate_size)
                    for _ in range(config.num_experts)
                ]
            )

    def reset_parameters(self) -> None:
        """Initialise the packed expert weights and zero the expert bias.

        Weights are drawn from N(0, std) with ``std = config.initializer_range``
        (0.02 if unset). Only valid when packed experts are used, since it reads
        the ``packed_*_proj`` attributes.
        """
        std = float(getattr(self.config, "initializer_range", 0.02))
        nn.init.normal_(self.packed_gate_proj, mean=0.0, std=std)
        nn.init.normal_(self.packed_up_proj, mean=0.0, std=std)
        nn.init.normal_(self.packed_down_proj, mean=0.0, std=std)
        with torch.no_grad():
            self.expert_bias.zero_()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE layer to ``(batch, seq, hidden)`` activations.

        Side effects: sets ``self._last_router_aux_loss`` (unscaled aux loss, or
        None when ``config.router_aux_loss_coef`` is falsy/<= 0) and
        ``self._last_router_logits`` (router probabilities reshaped to
        ``(batch, seq, num_experts)``, or None unless ``config.output_router_logits``).

        Returns:
            Tensor of the input's shape and dtype.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat, _router_logits, router_probs, top_scores, selected_experts = _router_forward(
            self.router,
            hidden_states,
            self.expert_bias,
        )

        if self.shared_experts is not None:
            shared_output = self.shared_experts(hidden_states_flat)
        else:
            shared_output = torch.zeros_like(hidden_states_flat)

        # Sort the flattened (token, slot) assignments by expert id so each expert's
        # rows are contiguous; the stable sort keeps token order within an expert.
        token_indices_sorted = torch.argsort(selected_experts.view(-1), stable=True)
        top_scores_sorted = top_scores.view(-1)[token_indices_sorted]
        token_to_expert = selected_experts.view(-1)[token_indices_sorted]
        # Flat assignment index = token * top_k + slot, so integer division recovers the token.
        token_indices_sorted = token_indices_sorted // self.config.num_experts_per_tok
        token_indices_expanded = token_indices_sorted.unsqueeze(-1).expand(-1, hidden_dim)
        routed_input = torch.gather(hidden_states_flat, dim=0, index=token_indices_expanded).contiguous()

        routed_output: torch.Tensor | None = None
        use_grouped_mm = bool(getattr(self.config, "enable_grouped_moe", True)) and _can_use_grouped_mm(
            routed_input
        )
        if use_grouped_mm:
            # Rows per expert and their cumulative end offsets (int32), used as the
            # group boundaries for F.grouped_mm's `offs`.
            expert_counts = torch.bincount(
                token_to_expert,
                minlength=self.config.num_experts,
            )
            grouped_offsets = torch.cumsum(
                expert_counts,
                dim=0,
                dtype=torch.int32,
            )
            packed_gate_proj = getattr(self, "packed_gate_proj", None)
            packed_up_proj = getattr(self, "packed_up_proj", None)
            packed_down_proj = getattr(self, "packed_down_proj", None)

            if all(
                isinstance(weight, nn.Parameter)
                for weight in (packed_gate_proj, packed_up_proj, packed_down_proj)
            ):
                gate_weights = packed_gate_proj
                up_weights = packed_up_proj
                down_weights = packed_down_proj
            else:
                # Per-expert modules: stack weights only for experts that received
                # rows, and recompute the offsets over that reduced expert set.
                active_expert_ids = torch.nonzero(expert_counts > 0, as_tuple=False).flatten()
                if active_expert_ids.numel() == 0:
                    routed_output = torch.zeros_like(routed_input)
                    gate_weights = up_weights = down_weights = None
                else:
                    grouped_offsets = torch.cumsum(
                        expert_counts.index_select(0, active_expert_ids),
                        dim=0,
                        dtype=torch.int32,
                    )
                    gate_weights = _get_grouped_projection_weights(
                        self,
                        active_expert_ids,
                        projection_name="gate_proj",
                    )
                    up_weights = _get_grouped_projection_weights(
                        self,
                        active_expert_ids,
                        projection_name="up_proj",
                    )
                    down_weights = _get_grouped_projection_weights(
                        self,
                        active_expert_ids,
                        projection_name="down_proj",
                    )

            if routed_output is None:
                # Gated MLP (act(gate) * up -> down) evaluated per expert group.
                gate_proj = F.grouped_mm(routed_input, gate_weights, offs=grouped_offsets)
                up_proj = F.grouped_mm(routed_input, up_weights, offs=grouped_offsets)
                activated = ACT2FN[self.config.hidden_act](gate_proj) * up_proj
                routed_output = F.grouped_mm(activated, down_weights, offs=grouped_offsets)
        else:
            routed_output = _dense_packed_moe_forward(
                self,
                routed_input,
                token_to_expert,
            )

        if routed_output is None:
            raise RuntimeError("MoE forward did not produce routed output.")

        if use_grouped_mm:
            # Drop references to grouped-GEMM temporaries before accumulation,
            # presumably to lower peak memory. Which names exist depends on the
            # branch taken above, hence the locals() checks.
            del expert_counts, grouped_offsets
            if "active_expert_ids" in locals():
                del active_expert_ids
            if "gate_weights" in locals():
                del gate_weights, up_weights, down_weights
            if "gate_proj" in locals():
                del gate_proj, up_proj, activated

        output = _accumulate_routed_output(
            shared_output=shared_output,
            routed_output=routed_output,
            top_scores_sorted=top_scores_sorted,
            token_indices_sorted=token_indices_sorted,
        )

        # The aux loss is only computed when a positive coefficient is configured;
        # the coefficient itself is not applied here.
        aux_loss_coef = float(getattr(self.config, "router_aux_loss_coef", 0.0) or 0.0)
        self._last_router_aux_loss = None
        if aux_loss_coef > 0.0:
            self._last_router_aux_loss = _router_aux_loss(
                router_probs,
                selected_experts,
                num_experts=self.config.num_experts,
            )

        # Despite the attribute name, this stores router probabilities (after
        # sigmoid/softmax), not the raw gate logits.
        self._last_router_logits = None
        if getattr(self.config, "output_router_logits", False):
            self._last_router_logits = router_probs.view(batch_size, seq_len, self.config.num_experts)

        return output.to(hidden_states.dtype).view(batch_size, seq_len, hidden_dim)


class AfmoeAttention(nn.Module):
    """Multi-headed attention with local/global pattern and output gating.

    Queries and keys get a per-head RMSNorm. Rotary position embeddings and a sliding
    window are used only on layers whose ``config.layer_types`` entry is
    ``"sliding_attention"``; every other layer attends without RoPE and without a window.
    The attention output is multiplied elementwise by ``sigmoid(gate_proj(hidden_states))``
    before the output projection.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # ``head_dim`` may be missing from the config *or* present but None; getattr's
        # default only covers the first case, so fall back explicitly for both.
        self.head_dim = getattr(config, "head_dim", None) or (
            config.hidden_size // config.num_attention_heads
        )
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_local_attention else None

        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=False
        )

        self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Produces one gate value per output channel of the attention (heads * head_dim).
        self.gate_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Compute gated self-attention over ``hidden_states``.

        Args:
            hidden_states: activations whose last dimension is ``hidden_size``.
            position_embeddings: ``(cos, sin)`` pair; only used on sliding-window layers.
            attention_mask: forwarded unchanged to the attention backend.
            past_key_value: optional cache; when given, the new keys/values are passed
                through ``past_key_value.update`` and the returned tensors are attended over.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention backend.

        Returns:
            Tensor with the leading shape of ``hidden_states`` and last dim ``hidden_size``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)
        gate_states = self.gate_proj(hidden_states)

        # Per-head normalization is applied before moving the head axis forward.
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # RoPE only on local (sliding-window) layers; global layers get no positional rotation.
        if self.is_local_attention:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                cache_kwargs,
            )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Flatten heads back into the hidden dim; ``reshape`` (unlike ``view``) also
        # handles a backend that returns a non-contiguous tensor.
        output = output.reshape(*input_shape, -1).contiguous()
        output = output * F.sigmoid(gate_states)
        return self.o_proj(output)


class AfmoeDecoderLayer(GradientCheckpointingLayer):
    """One AFMoE block: gated self-attention followed by a dense or MoE feed-forward.

    Each sub-layer is wrapped as ``x + post_norm(sublayer(pre_norm(x)))``. Layers with
    ``layer_idx >= config.num_dense_layers`` use ``AfmoeMoE``; earlier layers use
    ``AfmoeMLP``. ``moe_enabled`` records which one was built (``AfmoeModel`` reads it).
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]

        # Dual normalization for attention
        self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Dual normalization for FFN
        self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # MoE or dense FFN
        self.moe_enabled = layer_idx >= config.num_dense_layers
        if self.moe_enabled:
            self.mlp = AfmoeMoE(config)
        else:
            self.mlp = AfmoeMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """Apply attention and FFN sub-layers, each with pre/post norms and a residual add.

        ``position_ids`` and ``use_cache`` are forwarded to the attention module as extra
        keyword arguments; all other arguments are forwarded by name.
        """
        residual = hidden_states

        # Self Attention with dual normalization
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN with dual normalization
        residual = hidden_states
        hidden_states = self.pre_mlp_layernorm(hidden_states)
        # Dense and MoE FFNs are invoked identically, so no branch on moe_enabled is needed.
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)
        return residual + hidden_states


class AfmoePreTrainedModel(PreTrainedModel):
    """Shared base for AFMoE models.

    Defines only class-level settings consumed by the transformers ``PreTrainedModel``
    machinery: the config class, the weight-name prefix, supported features, device
    placement hints and the modules listed in ``_keep_in_fp32_modules``.
    """

    # Identity
    config_class = AfmoeConfig
    base_model_prefix = "model"

    # Supported features
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_attention_backend = True

    # Device placement hints
    _no_split_modules = ["AfmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Normalization submodules named for transformers' fp32-retention handling.
    _keep_in_fp32_modules = [
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
        "q_norm",
        "k_norm",
        "norm",
    ]


class AfmoeModel(AfmoePreTrainedModel):
    """AFMoE decoder stack: token embeddings, ``num_hidden_layers`` decoder layers, final norm.

    Besides the returned ``MoeModelOutputWithPast``, every forward pass stores the mean of
    the MoE layers' router auxiliary losses (``None`` when no layer produced one) on
    ``self._last_router_aux_loss``; ``AfmoeForCausalLM`` reads it from there.
    """

    _no_split_modules = ["AfmoeDecoderLayer"]

    def __init__(self, config: AfmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                AfmoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = AfmoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        """Run the decoder stack.

        At least one of ``input_ids`` / ``inputs_embeds`` is required; when both are given,
        ``inputs_embeds`` is used and ``input_ids`` is ignored. ``attention_mask`` may also
        be a pre-built dict mapping each layer type to its mask.

        Returns:
            ``MoeModelOutputWithPast`` with the normed last hidden state, the cache, and
            (when ``config.output_router_logits``) one router tensor per MoE layer.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify at least one of input_ids or inputs_embeds.")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build one mask per attention type unless the caller already supplied the mapping.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # muP: scale the embeddings by sqrt(hidden_size).
        if self.config.mup_enabled:
            hidden_states = hidden_states * (self.config.hidden_size**0.5)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        router_aux_losses = []
        collect_router_logits = bool(getattr(self.config, "output_router_logits", False))
        router_logits = [] if collect_router_logits else None

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            if not getattr(decoder_layer, "moe_enabled", False):
                continue

            # MoE layers stash their router outputs on the module; harvest and then clear
            # them so values from this pass are not picked up again later.
            layer_aux_loss = getattr(decoder_layer.mlp, "_last_router_aux_loss", None)
            if layer_aux_loss is not None:
                router_aux_losses.append(layer_aux_loss)
            if router_logits is not None:
                layer_router_logits = getattr(decoder_layer.mlp, "_last_router_logits", None)
                if layer_router_logits is not None:
                    router_logits.append(layer_router_logits)
            decoder_layer.mlp._last_router_aux_loss = None
            decoder_layer.mlp._last_router_logits = None

        hidden_states = self.norm(hidden_states)
        # Layers may live on different devices (e.g. device_map="auto"); torch.stack
        # requires a single device, so gather the per-layer losses first.
        self._last_router_aux_loss = (
            torch.stack(
                [layer_loss.to(hidden_states.device) for layer_loss in router_aux_losses]
            ).mean()
            if router_aux_losses
            else None
        )

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            router_logits=tuple(router_logits) if router_logits else None,
        )


class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
    """AFMoE decoder with a language-modeling head.

    When labels are given, the LM loss is computed with ``self.loss_function``. If the
    backbone produced a router auxiliary loss and ``config.router_aux_loss_coef > 0``,
    ``coef * aux_loss`` is added to it.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = AfmoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        """Run the backbone, project to vocabulary logits, and optionally compute the loss.

        Args:
            input_ids / inputs_embeds: at least one is required (see ``AfmoeModel.forward``).
            labels: when given, a loss is computed from ``logits`` and ``labels``.
            logits_to_keep: an int keeps the last N positions (0 keeps all); a tensor is
                used directly as the index into the sequence dimension.
            token_type_ids: accepted for API compatibility and ignored.
        """
        del token_type_ids
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # The backbone stashes the router aux loss on itself rather than in its output.
        aux_loss = getattr(self.model, "_last_router_aux_loss", None)
        aux_loss_coef = float(getattr(self.config, "router_aux_loss_coef", 0.0) or 0.0)
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        if loss is not None and aux_loss is not None and aux_loss_coef > 0.0:
            # Under model parallelism the backbone and the LM head can sit on different
            # devices; align before adding.
            loss = loss + (aux_loss.to(loss.device) * aux_loss_coef)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

@dataclass(frozen=True)
class MoondreamVisionConfig:
    """Immutable hyper-parameters for the Moondream vision tower and its crop pipeline.

    Sizes are in pixels unless noted. ``overlap_margin`` is measured in patches.
    """

    enc_dim: int = 1152  # encoder hidden width
    enc_patch_size: int = 14  # square patch side, in pixels
    enc_n_layers: int = 27  # number of encoder blocks
    enc_ff_dim: int = 4304  # encoder MLP hidden width
    enc_n_heads: int = 16  # attention heads per encoder block
    proj_out_dim: int = 2048  # output width of the projection MLP
    crop_size: int = 378  # side of every (square) crop fed to the encoder
    in_channels: int = 3  # image channels per patch
    max_crops: int = 12  # upper bound on local tiles per image
    overlap_margin: int = 4  # overlap between neighbouring tiles, in patches
    proj_inner_dim: int = 8192  # hidden width of the projection MLP

    @property
    def image_seq_len(self) -> int:
        """Patch tokens per crop: the patch-grid side length, squared."""
        patches_per_side = self.crop_size // self.enc_patch_size
        return patches_per_side * patches_per_side


def select_tiling(height: int, width: int, crop_size: int, max_crops: int) -> tuple[int, int]:
    """Pick a ``(rows, cols)`` grid of crops for an image of ``height`` x ``width`` pixels.

    Returns ``(1, 1)`` when either side already fits in one crop. If even the minimum
    grid needed to cover the image exceeds ``max_crops``, both sides are shrunk by the same
    factor. Otherwise the grid follows the image aspect ratio, is raised to the minimum
    cover, and the larger side is trimmed if the product overshoots ``max_crops``.
    Both entries are always at least 1.
    """
    # At least one side fits within a single crop: no tiling.
    if not (height > crop_size and width > crop_size):
        return (1, 1)

    rows_needed = math.ceil(height / crop_size)
    cols_needed = math.ceil(width / crop_size)
    cells_needed = rows_needed * cols_needed

    if cells_needed > max_crops:
        shrink = math.sqrt(max_crops / cells_needed)
        return (
            max(1, math.floor(rows_needed * shrink)),
            max(1, math.floor(cols_needed * shrink)),
        )

    # Aspect-ratio-proportional grid, never smaller than the minimum cover.
    rows = max(math.floor(math.sqrt(max_crops * height / width)), rows_needed)
    cols = max(math.floor(math.sqrt(max_crops * width / height)), cols_needed)

    if rows * cols > max_crops:
        if cols > rows:
            cols = math.floor(max_crops / rows)
        else:
            rows = math.floor(max_crops / cols)

    return (max(1, rows), max(1, cols))


def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> tuple[np.ndarray, tuple[int, int]]:
    """Split an ``(H, W, C)`` uint8 image into one global crop plus overlapping local tiles.

    Args:
        image: ``(H, W, C)`` array; passed to ``PIL.Image.fromarray``.
        overlap_margin: overlap between neighbouring tiles, in patches.
        max_crops: upper bound on the number of local tiles.
        base_size: ``(height, width)`` of every output crop.
        patch_size: patch side in pixels; converts ``overlap_margin`` to pixels.

    Returns:
        ``(crops, tiling)``, where ``crops`` is ``(rows * cols + 1, base_h, base_w, C)`` uint8.
        Slot 0 is the whole image resized to ``base_size``. Slots ``1 + r * cols + c`` hold
        the tile at grid position ``(r, c)``, zero-padded where the tile runs past the image.
        ``tiling`` is the ``(rows, cols)`` grid returned by ``select_tiling``.
    """
    src_h, src_w = image.shape[:2]
    margin_px = patch_size * overlap_margin
    both_margins_px = 2 * margin_px
    # Stride between tiles: the crop width minus the overlap on both sides.
    window_px = (base_size[0] // patch_size - 2 * overlap_margin) * patch_size

    tiling = select_tiling(
        src_h - both_margins_px,
        src_w - both_margins_px,
        window_px,
        max_crops,
    )
    rows, cols = tiling

    crops = np.zeros((rows * cols + 1, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8)

    source = Image.fromarray(image)
    # Resize so the tile grid plus the outer margins covers the image exactly.
    tiled_h = rows * window_px + both_margins_px
    tiled_w = cols * window_px + both_margins_px
    tiled = np.asarray(
        source.resize((int(tiled_w), int(tiled_h)), resample=Image.Resampling.LANCZOS)
    )

    crops[0] = np.asarray(
        source.resize((int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS)
    )

    for row in range(rows):
        top = row * window_px
        bottom = min(top + base_size[0], tiled.shape[0])
        for col in range(cols):
            left = col * window_px
            right = min(left + base_size[1], tiled.shape[1])
            tile = tiled[top:bottom, left:right]
            crops[1 + row * cols + col, : tile.shape[0], : tile.shape[1]] = tile

    return crops, tiling


@torch.compiler.disable
def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    tiling_h, tiling_w = tiling
    crop_height, crop_width = crops[0].shape[:2]
    margin_pixels = overlap_margin * patch_size
    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels

    reconstructed = torch.zeros(
        (output_h, output_w, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for i, crop in enumerate(crops):
        tile_y = i // tiling_w
        tile_x = i % tiling_w
        x_start = 0 if tile_x == 0 else margin_pixels
        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
        y_start = 0 if tile_y == 0 else margin_pixels
        y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
        out_x = tile_x * (crop_width - 2 * margin_pixels)
        out_y = tile_y * (crop_height - 2 * margin_pixels)
        reconstructed[
            out_y + y_start : out_y + y_end,
            out_x + x_start : out_x + x_end,
        ] = crop[y_start:y_end, x_start:x_end]

    return reconstructed


@torch.compiler.disable
def prepare_crops(
    image: Image.Image,
    config: MoondreamVisionConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Turn a PIL image into normalized NCHW crops ready for the vision encoder.

    The image is converted to RGB and split by ``overlap_crop_image`` using the crop
    settings in ``config``. The crops are moved to ``device``/``dtype`` and mapped from
    ``[0, 255]`` to ``[-1, 1]``.

    Returns:
        ``(crops, tiling)``: crops shaped ``(n_crops, 3, crop_size, crop_size)`` and the
        ``(rows, cols)`` tile grid.
    """
    rgb = np.array(image.convert("RGB"))
    crops_nhwc, tiling = overlap_crop_image(
        rgb,
        overlap_margin=config.overlap_margin,
        max_crops=config.max_crops,
        base_size=(config.crop_size, config.crop_size),
        patch_size=config.enc_patch_size,
    )
    # NHWC -> NCHW before handing to torch.
    crops_nchw = torch.from_numpy(crops_nhwc.transpose(0, 3, 1, 2)).to(device=device, dtype=dtype)
    # x / 255 -> [0, 1]; (x - 0.5) / 0.5 -> [-1, 1]; done in place in the target dtype.
    normalized = crops_nchw.div_(255.0).sub_(0.5).div_(0.5)
    return normalized, tiling


def create_patches(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Flatten non-overlapping ``patch_size`` squares of an NCHW batch into token vectors.

    Returns ``(batch, rows * cols, channels * patch_size**2)``. Patches are ordered
    row-major over the grid, and each vector is channel-major, then patch row, then
    patch column. H and W must be multiples of ``patch_size`` for the reshape to succeed.
    """
    b, c, h, w = x.shape
    rows, cols = h // patch_size, w // patch_size
    tiles = x.reshape(b, c, rows, patch_size, cols, patch_size)
    # -> (b, rows, cols, c, patch_row, patch_col)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    return tiles.reshape(b, rows * cols, c * patch_size * patch_size)


class MoondreamAttention(nn.Module):
    """Unmasked multi-head self-attention with a fused QKV projection.

    Input and output are ``(batch, seq_len, dim)``; ``dim`` is split evenly over ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, dtype: torch.dtype) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.qkv = nn.Linear(dim, 3 * dim, dtype=dtype)
        self.proj = nn.Linear(dim, dim, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, length, width = x.shape
        per_head = width // self.n_heads

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (bsz, length, width) -> (bsz, heads, length, per_head)
            return t.view(bsz, length, self.n_heads, per_head).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.qkv(x).chunk(3, dim=-1))
        attended = F.scaled_dot_product_attention(query, key, value)
        merged = attended.transpose(1, 2).reshape(bsz, length, width)
        return self.proj(merged)


class MoondreamMLP(nn.Module):
    """Two linear layers with a tanh-approximated GELU in between."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, dtype: torch.dtype) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim, dtype=dtype)
        self.fc2 = nn.Linear(hidden_dim, out_dim, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = F.gelu(hidden, approximate="tanh")
        return self.fc2(hidden)


class MoondreamVisionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln1(x))`` then ``x + mlp(ln2(x))``."""

    def __init__(self, config: MoondreamVisionConfig, dtype: torch.dtype) -> None:
        super().__init__()
        width = config.enc_dim
        self.ln1 = nn.LayerNorm(width, dtype=dtype)
        self.attn = MoondreamAttention(width, config.enc_n_heads, dtype)
        self.ln2 = nn.LayerNorm(width, dtype=dtype)
        self.mlp = MoondreamMLP(width, config.enc_ff_dim, width, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))


class MoondreamVisionTower(nn.Module):
    """Moondream vision encoder plus projection MLP.

    ``encode_image`` makes one global crop and overlapping local tiles from a PIL image
    and encodes every crop. It stitches the local-tile features back into one spatial
    map, pools that map to the per-crop patch grid, concatenates it with the global
    crop's features, and projects the result to ``(image_seq_len, proj_out_dim)``.
    """

    def __init__(
        self,
        config: MoondreamVisionConfig | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.config = config or MoondreamVisionConfig()
        self.patch_emb = nn.Linear(
            self.config.enc_patch_size * self.config.enc_patch_size * self.config.in_channels,
            self.config.enc_dim,
            dtype=dtype,
        )
        self.blocks = nn.ModuleList(
            [MoondreamVisionBlock(self.config, dtype=dtype) for _ in range(self.config.enc_n_layers)]
        )
        self.post_ln = nn.LayerNorm(self.config.enc_dim, dtype=dtype)
        self.proj_mlp = MoondreamMLP(
            self.config.enc_dim * 2,
            self.config.proj_inner_dim,
            self.config.proj_out_dim,
            dtype,
        )
        self.pos_emb = nn.Parameter(
            torch.zeros(1, self.config.image_seq_len, self.config.enc_dim, dtype=dtype)
        )

    @property
    def image_seq_len(self) -> int:
        return self.config.image_seq_len

    @property
    def _patch_grid_size(self) -> int:
        """Side length of the per-crop patch grid, so ``image_seq_len == side ** 2``.

        This is distinct from ``enc_n_layers`` (encoder depth); the two only coincide
        for the default config (378 // 14 == 27 layers).
        """
        return self.config.crop_size // self.config.enc_patch_size

    def encode_crops(self, inputs_bchw: torch.Tensor) -> torch.Tensor:
        """Encode ``(n, C, crop, crop)`` crops into ``(n, image_seq_len, enc_dim)`` features."""
        x = create_patches(inputs_bchw, self.config.enc_patch_size)
        x = self.patch_emb(x)
        x = x + self.pos_emb
        for block in self.blocks:
            x = block(x)
        return self.post_ln(x)

    def project_features(self, global_features: torch.Tensor, reconstructed: torch.Tensor) -> torch.Tensor:
        """Pool the stitched ``(H', W', enc_dim)`` map to the patch grid and project.

        Args:
            global_features: ``(image_seq_len, enc_dim)`` features of the global crop.
            reconstructed: ``(H', W', enc_dim)`` stitched local-tile features.

        Returns:
            ``(image_seq_len, proj_out_dim)`` projected features.
        """
        grid = self._patch_grid_size
        reconstructed = reconstructed.permute(2, 0, 1)
        reconstructed = F.adaptive_avg_pool2d(reconstructed, output_size=(grid, grid))
        reconstructed = reconstructed.permute(1, 2, 0).reshape(self.image_seq_len, self.config.enc_dim)
        return self.proj_mlp(torch.cat([global_features, reconstructed], dim=-1))

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode one PIL image into ``(image_seq_len, proj_out_dim)`` features.

        Raises:
            TypeError: if ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected PIL image, got {type(image)!r}")

        device = self.pos_emb.device
        dtype = self.pos_emb.dtype
        crops, tiling = prepare_crops(image, self.config, device=device, dtype=dtype)
        outputs = self.encode_crops(crops)
        # Crop 0 is the global view; the rest are local tiles.
        global_features = outputs[0]
        grid = self._patch_grid_size
        local_features = outputs[1:].view(-1, grid, grid, self.config.enc_dim)
        # Features live on the patch grid, so the overlap margin is in grid cells (patch_size=1).
        reconstructed = reconstruct_from_crops(
            local_features,
            tiling,
            patch_size=1,
            overlap_margin=self.config.overlap_margin,
        )
        return self.project_features(global_features, reconstructed)

    def encode_images(self, images: list[Image.Image]) -> torch.Tensor:
        """Encode each image and stack to ``(len(images), image_seq_len, proj_out_dim)``."""
        encoded = [self.encode_image(image) for image in images]
        return torch.stack(encoded, dim=0)

def build_image_token_span(
    *,
    image_start_token_id: int | None,
    image_token_id: int | None,
    image_end_token_id: int | None,
    image_seq_len: int,
    bos_token_id: int | None = None,
) -> list[int]:
    if image_start_token_id is None:
        raise ValueError("image_start_token_id is not configured.")
    if image_token_id is None:
        raise ValueError("image_token_id is not configured.")
    if image_end_token_id is None:
        raise ValueError("image_end_token_id is not configured.")

    token_ids: list[int] = []
    if bos_token_id is not None:
        token_ids.append(bos_token_id)
    token_ids.append(image_start_token_id)
    token_ids.extend([image_token_id] * image_seq_len)
    token_ids.append(image_end_token_id)
    return token_ids


@dataclass
class TrinityVLMCausalLMOutputWithPast(ModelOutput):
    """Output container for the Trinity vision-language causal LM.

    Holds the same fields that ``AfmoeForCausalLM`` fills in its
    ``MoeCausalLMOutputWithPast``, plus ``image_hidden_states``. Every field
    defaults to ``None``.

    NOTE(review): ``ModelOutput`` subclasses are commonly also read positionally,
    so the field order is likely part of the interface -- keep it stable.
    """

    loss: torch.FloatTensor | None = None  # presumably LM loss (+ weighted aux loss) when labels are given -- confirm
    aux_loss: torch.FloatTensor | None = None  # router auxiliary loss, if one was computed
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None  # cache returned by the language model
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    router_logits: tuple[torch.FloatTensor] | None = None  # per-MoE-layer router outputs, if requested
    image_hidden_states: torch.FloatTensor | None = None  # NOTE(review): presumably projected image features -- confirm in forward()


class VisionBridge(nn.Module):
    """Projector from vision features to the text model width.

    Pipeline: LayerNorm -> Linear(in, hidden) -> GELU (tanh approximation) -> Linear(hidden, out).
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, dtype: torch.dtype) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(in_dim, dtype=dtype)
        self.fc1 = nn.Linear(in_dim, hidden_dim, dtype=dtype)
        self.fc2 = nn.Linear(hidden_dim, out_dim, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(self.norm(x))
        activated = torch.nn.functional.gelu(expanded, approximate="tanh")
        return self.fc2(activated)


class TrinityVLMForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Vision-language model: a Moondream vision tower and MLP projector feeding an Afmoe causal LM.

    Images are encoded by ``vision_tower``, mapped to the text hidden size by
    ``multi_modal_projector``, and written into the text embeddings at the positions
    holding ``config.image_token_id`` before ``language_model`` runs.
    """

    config_class = TrinityVLMConfig
    base_model_prefix = "trinity_vlm"
    main_input_name = "input_ids"

    def __init__(self, config: TrinityVLMConfig) -> None:
        super().__init__(config)
        torch_dtype = self._resolve_torch_dtype(config)
        # Copy so the pop below does not mutate config.vision_config.
        vision_config = dict(config.vision_config)
        # A projector width stored inside vision_config wins over the top-level value; it is
        # popped so it is not forwarded to MoondreamVisionConfig(**vision_config).
        projector_hidden_dim = int(vision_config.pop("projector_hidden_dim", config.projector_hidden_dim))

        self.torch_dtype = torch_dtype
        self.language_model = AfmoeForCausalLM(self._load_trinity_text_config(config))
        self.vision_tower = MoondreamVisionTower(
            config=MoondreamVisionConfig(**vision_config),
            dtype=torch_dtype,
        )
        self.multi_modal_projector = VisionBridge(
            in_dim=config.vision_feature_dim,
            hidden_dim=projector_hidden_dim,
            out_dim=self.language_model.config.hidden_size,
            dtype=torch_dtype,
        )

        # Mirror the text model's sizes and special-token ids onto the top-level config
        # (e.g. ``forward`` passes ``self.config.vocab_size`` to the loss function).
        self.config.hidden_size = self.language_model.config.hidden_size
        self.config.vocab_size = self.language_model.config.vocab_size
        self.config.bos_token_id = self.language_model.config.bos_token_id
        self.config.eos_token_id = self.language_model.config.eos_token_id
        self.config.pad_token_id = self.language_model.config.pad_token_id
        self.post_init()

    @staticmethod
    def _resolve_torch_dtype(config: TrinityVLMConfig) -> torch.dtype:
        """Return the dtype used to build the vision tower and projector.

        Reads ``config.dtype`` and falls back to ``config.torch_dtype``. Accepts a
        ``torch.dtype`` or a dtype name with or without a ``"torch."`` prefix
        (``"bfloat16"`` / ``"torch.bfloat16"``). Anything else, including ``"auto"``
        or a missing value, resolves to ``torch.bfloat16``.
        """
        dtype_value = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
        if isinstance(dtype_value, torch.dtype):
            return dtype_value
        if isinstance(dtype_value, str):
            # getattr(torch, name) can also return functions/classes (e.g. "tensor"), which
            # would crash the submodule constructors; only accept real dtypes.
            candidate = getattr(torch, dtype_value.removeprefix("torch."), None)
            if isinstance(candidate, torch.dtype):
                return candidate
        return torch.bfloat16

    @staticmethod
    def _load_trinity_text_config(config: TrinityVLMConfig) -> AfmoeConfig:
        """Build the ``AfmoeConfig`` for the text model from ``config.text_config``.

        The top-level vocab size and special-token ids override whatever
        ``text_config`` carries, and several runtime options are forced here.

        Raises:
            ValueError: if ``config.text_config`` is missing or empty.
        """
        if not getattr(config, "text_config", None):
            raise ValueError("TrinityVLMConfig.text_config must be present in config.json.")
        text_config = AfmoeConfig(**config.text_config)
        text_config.vocab_size = config.vocab_size
        text_config.bos_token_id = config.bos_token_id
        text_config.eos_token_id = config.eos_token_id
        text_config.pad_token_id = config.pad_token_id
        text_config.packed_experts = True
        text_config.enable_grouped_moe = getattr(config, "enable_grouped_moe", True)
        text_config.output_router_logits = getattr(config, "output_router_logits", False)
        # Attention backend is pinned to SDPA regardless of the configured value.
        text_config._attn_implementation = "sdpa"
        return text_config

    @property
    def device(self) -> torch.device:
        """Device of the text input-embedding weights."""
        return self.get_input_embeddings().weight.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the text input-embedding weights."""
        return self.get_input_embeddings().weight.dtype

    def get_input_embeddings(self):
        """Return the text model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the text model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the text model's output (LM head) module."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the text model's output (LM head) module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self):
        """Return the text model's decoder."""
        return self.language_model.get_decoder()

    def set_decoder(self, decoder):
        """Replace the text model's decoder."""
        self.language_model.set_decoder(decoder)

    def build_image_token_span(self, *, include_bos: bool = True) -> list[int]:
        """Return the placeholder token ids for one image, using this model's config.

        Args:
            include_bos: prepend ``config.bos_token_id`` (when it is not None).
        """
        return build_image_token_span(
            image_start_token_id=self.config.image_start_token_id,
            image_token_id=self.config.image_token_id,
            image_end_token_id=self.config.image_end_token_id,
            image_seq_len=self.config.image_seq_len,
            bos_token_id=self.config.bos_token_id if include_bos else None,
        )

    def _project_image_feature_tensor(self, image_features: torch.Tensor) -> torch.Tensor:
        """Bring precomputed image features to the text hidden size on this model's device/dtype.

        Features whose last dim already equals the text hidden size are used as-is;
        features with ``config.vision_feature_dim`` go through the projector.

        NOTE(review): if ``hidden_size == vision_feature_dim`` the first branch always wins and
        raw vision features are never projected — confirm this ambiguity cannot arise.

        Raises:
            ValueError: if the last dim matches neither size.
        """
        if image_features.shape[-1] == self.language_model.config.hidden_size:
            return image_features.to(device=self.device, dtype=self.dtype)
        if image_features.shape[-1] == self.config.vision_feature_dim:
            return self.multi_modal_projector(image_features.to(device=self.device, dtype=self.dtype))
        raise ValueError("Tensor image features must already be in text hidden size or vision feature size.")

    def _empty_image_features(self) -> torch.Tensor:
        """Return a zero-image feature tensor of shape (0, image_seq_len, text hidden size)."""
        return torch.empty(
            0,
            self.config.image_seq_len,
            self.language_model.config.hidden_size,
            device=self.device,
            dtype=self.dtype,
        )

    @torch.compiler.disable
    def get_image_features(
        self,
        images: list[Any] | list[list[Any]] | torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        """Encode/project images into features with the text hidden size as last dim.

        Accepted ``images`` forms:
          * ``None`` -> returns ``None``.
          * 3-D tensor ``[n_images, seq, dim]`` of precomputed features -> one image per sample.
          * 4-D tensor ``[batch, images, seq, dim]`` -> ``images`` images per sample.
          * flat list of PIL images -> one image per sample.
          * list of lists/tuples of PIL images -> one inner list per sample (may be empty).

        Returns:
            ``(features, counts)``: ``features`` stacks all images along dim 0 and
            ``counts`` is a 1-D long tensor with the number of images per sample.

        Raises:
            ValueError: tensor input with an unsupported rank or last dimension.
            TypeError: input that is neither a tensor nor a list/tuple, or list
                entries that are not PIL images.
        """
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            if images.ndim == 3:
                projected = self._project_image_feature_tensor(images)
                counts = torch.ones(projected.size(0), device=projected.device, dtype=torch.long)
                return projected, counts
            if images.ndim == 4:
                batch, num_images = images.shape[:2]
                projected = self._project_image_feature_tensor(images.flatten(0, 1))
                counts = torch.full((batch,), num_images, device=projected.device, dtype=torch.long)
                return projected, counts
            raise ValueError("Tensor images must have shape [n_images, seq, dim] or [batch, images, seq, dim].")

        if not isinstance(images, (list, tuple)):
            raise TypeError(f"Unsupported image batch type: {type(images)!r}")

        if not images:
            empty_counts = torch.empty(0, device=self.device, dtype=torch.long)
            return self._empty_image_features(), empty_counts

        # The first entry decides the layout: flat images vs. per-sample image lists.
        first_item = images[0]
        if isinstance(first_item, Image.Image):
            image_batches = [[image] for image in images]
        elif isinstance(first_item, (list, tuple)):
            image_batches = [list(sample_images) for sample_images in images]
        else:
            raise TypeError(f"Unsupported image batch type: {type(first_item)!r}")

        flat_images: list[Image.Image] = []
        image_counts = []
        for sample_images in image_batches:
            for image in sample_images:
                if not isinstance(image, Image.Image):
                    raise TypeError(f"Expected PIL images, got {type(image)!r}")
            flat_images.extend(sample_images)
            image_counts.append(len(sample_images))

        # Every sample had an empty image list: keep per-sample counts (all zero).
        if not flat_images:
            return self._empty_image_features(), torch.tensor(image_counts, device=self.device, dtype=torch.long)

        image_features = self.vision_tower.encode_images(flat_images).to(device=self.device, dtype=self.dtype)
        projected = self.multi_modal_projector(image_features)
        return projected, torch.tensor(image_counts, device=self.device, dtype=torch.long)

    def _get_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor,
    ) -> torch.BoolTensor:
        """Return a boolean mask of image-placeholder positions, shaped like ``inputs_embeds[..., 0]``.

        With ``input_ids`` the mask is ``input_ids == image_token_id``; without them,
        positions whose embedding exactly equals the image token's embedding are used.
        Asserts that the number of placeholders equals ``n_images * features_per_image``.
        """
        if input_ids is None:
            image_token = torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            special_image_mask = inputs_embeds == self.get_input_embeddings()(image_token)
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        torch._assert(
            n_image_tokens == n_image_features,
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    def _merge_image_features(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor,
        image_counts: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Scatter image features into ``inputs_embeds`` at the image-placeholder positions.

        When ``input_ids`` is available, per-sample counts of placeholder, start and end
        tokens are validated against ``image_counts`` first.

        Returns:
            A new embeddings tensor (``masked_scatter`` is out-of-place).
        """
        if input_ids is not None:
            num_images = image_counts.to(dtype=torch.long)
            # Counts are per sample; a length mismatch (e.g. a batch that was expanded without
            # expanding ``images``) would otherwise show up as a confusing broadcast error or a
            # generic count mismatch below.
            torch._assert(
                num_images.shape[0] == input_ids.shape[0],
                f"Expected one image count per sample, got {num_images.shape[0]} counts "
                f"for batch size {input_ids.shape[0]}.",
            )
            expected_image_token_counts = num_images * image_features.shape[1]
            actual_image_token_counts = (input_ids == self.config.image_token_id).sum(dim=1)
            torch._assert(
                torch.all(actual_image_token_counts == expected_image_token_counts),
                "Image placeholder count mismatch.",
            )
            if self.config.image_start_token_id is not None:
                start_counts = (input_ids == self.config.image_start_token_id).sum(dim=1)
                torch._assert(torch.all(start_counts == num_images), "image_start token count mismatch.")
            if self.config.image_end_token_id is not None:
                end_counts = (input_ids == self.config.image_end_token_id).sum(dim=1)
                torch._assert(torch.all(end_counts == num_images), "image_end token count mismatch.")

        special_image_mask = self._get_placeholder_mask(input_ids, inputs_embeds, image_features)
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        flat_image_features = image_features.reshape(-1, image_features.shape[-1]).to(
            inputs_embeds.device,
            inputs_embeds.dtype,
        )
        return inputs_embeds.masked_scatter(special_image_mask, flat_image_features)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        images: list[Any] | list[list[Any]] | torch.Tensor | None = None,
        attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        use_cache: bool | None = None,
        **kwargs,
    ) -> tuple | TrinityVLMCausalLMOutputWithPast:
        """Run the multimodal forward pass.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When ``images``
        yields at least one image, its projected features replace the embeddings at
        ``config.image_token_id`` positions. Extra ``kwargs`` are forwarded both to the
        text model and to ``self.loss_function``.

        Args:
            logits_to_keep: an int ``k`` keeps the last ``k`` positions (``0`` keeps all);
                a tensor is used directly as the sequence-dim index.

        Returns:
            Always a ``TrinityVLMCausalLMOutputWithPast``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_hidden_states = None
        if images is not None:
            image_outputs = self.get_image_features(images)
            if image_outputs is not None:
                image_features, image_counts = image_outputs
                # Zero images (e.g. ``images=[]``) means a text-only pass: skip merging.
                if image_features.numel() > 0:
                    image_hidden_states = image_features
                    inputs_embeds = self._merge_image_features(
                        input_ids=input_ids,
                        inputs_embeds=inputs_embeds,
                        image_features=image_features,
                        image_counts=image_counts,
                    )

        outputs = self.language_model.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # NOTE(review): presumably set by the inner model's forward; confirm it cannot be stale
        # from a previous call.
        aux_loss = getattr(self.language_model.model, "_last_router_aux_loss", None)
        aux_loss_coef = float(getattr(self.config, "router_aux_loss_coef", 0.0) or 0.0)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.language_model.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        if loss is not None and aux_loss is not None and aux_loss_coef > 0.0:
            loss = loss + (aux_loss * aux_loss_coef)

        return TrinityVLMCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
            image_hidden_states=image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        images=None,
        attention_mask=None,
        logits_to_keep=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """Delegate to ``GenerationMixin`` and attach ``images`` only when they are needed.

        Images are forwarded on the first iteration or whenever ``use_cache`` is False;
        presumably later cached steps rely on the image tokens already being in the cache.

        NOTE(review): depends on the generation loop passing ``is_first_iteration``; if the
        installed transformers version does not, images would never reach ``forward`` during
        cached generation. Confirm against the pinned transformers version.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        if is_first_iteration or not kwargs.get("use_cache", True):
            model_inputs["images"] = images

        return model_inputs


__all__ = [
    "AfmoeConfig",
    "AfmoeForCausalLM",
    "AfmoeModel",
    "AfmoePreTrainedModel",
    "MoondreamVisionConfig",
    "MoondreamVisionTower",
    "TrinityVLMForConditionalGeneration",
]