File size: 46,222 Bytes
67e9774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Optional

import sympy

import torch
from torch._dynamo.utils import counters
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    context_add_strides,
    context_add_using_tf32,
    mm_operations,
)
from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.torch_version import TorchVersion

from .. import config as inductor_config, ir
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, is_triton
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
    lowerings as L,
    register_lowering,
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    realize_inputs,
    TritonTemplate,
)
from ..utils import (
    _use_cutlass_for_op,
    get_k_splits,
    get_tma_workspace_arg,
    use_aten_gemm_kernels,
    use_ck_gemm_template,
    use_ck_tile_gemm_template,
    use_cpp_gemm_template,
    use_cutlass_template,
    use_decompose_k_choice,
    use_triton_template,
    use_triton_tma_template,
)
from .mm_common import (
    _is_static_problem,
    addmm_epilogue,
    mm_args,
    mm_config_kwargs,
    mm_grid,
    mm_options,
    persistent_mm_grid,
    persistent_mm_options,
    scale_mm_epilogue,
    scaled_mm_options,
)


# Probe for Triton once at import time; callers compare ``triton_version``
# against version strings and check ``has_triton`` before using Triton paths.
try:
    import triton
except ImportError:
    # Sentinel version that compares lower than any real Triton release.
    triton_version = TorchVersion("0.0.0")
    has_triton = False
else:
    triton_version = TorchVersion(triton.__version__)
    has_triton = True

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

# Triton GEMM template used by the aten.mm lowering.
#
# Each program instance computes one BLOCK_M x BLOCK_N tile of the output.
# Program ids are swizzled in groups of GROUP_M rows (the "re-order program ID"
# section of the kernel). The kernel loops over K in BLOCK_K chunks, masking
# the K tail only when EVEN_K is false. Contiguity/divisibility hints
# (tl.max_contiguous / tl.multiple_of) are applied to the row offsets of A and
# the column offsets of B only when the operand's strides show it is dense.
#
# The kernel source is picked at import time. On ROCm with Triton < 3.3.0 a
# variant without the `M >= BLOCK_M` / `N >= BLOCK_N` guards on those hints is
# used (see the FIXME below).
mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=(
        r"""

{{def_kernel("A", "B")}}

    M = {{size("A", 0)}}

    N = {{size("B", 1)}}

    K = {{size("A", 1)}}

    if M * N == 0:

        # early exit due to zero-size input(s)

        return

    stride_am = {{stride("A", 0)}}

    stride_ak = {{stride("A", 1)}}

    stride_bk = {{stride("B", 0)}}

    stride_bn = {{stride("B", 1)}}



    # based on triton.ops.matmul

    pid = tl.program_id(0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M

    grid_n = (N + BLOCK_N - 1) // BLOCK_N



    # re-order program ID for better L2 performance

    width = GROUP_M * grid_n

    group_id = pid // width

    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)

    pid_m = group_id * GROUP_M + (pid % group_size)

    pid_n = (pid % width) // (group_size)

    tl.assume(pid_m >= 0)

    tl.assume(pid_n >= 0)



    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M:

        offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)

    else:

        offs_a_m = rm % M

    if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N:

        offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

    else:

        offs_b_n = rn % N

    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)



    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):

        {% if not EVEN_K %}

        a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K)

        b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K)

        {% endif %}

        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)

        b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)



        idx_m = offs_a_m[:, None]

        idx_n = a_k_idx_vals

        {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}}



        idx_m = b_k_idx_vals

        idx_n = offs_b_n[None, :]

        {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}



        {% if USE_FAST_ACCUM %}

        acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)

        {% else %}

        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)

        {% endif %}



    # rematerialize rm and rn to save registers

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]

    idx_n = rn[None, :]

    mask = (idx_m < M) & (idx_n < N)



    # inductor generates a suffix

    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}

"""
        if (torch.version.hip is None) or triton_version >= "3.3.0"
        # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
        # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
        # See more details in https://github.com/pytorch/pytorch/pull/146293
        else r"""

{{def_kernel("A", "B")}}

    M = {{size("A", 0)}}

    N = {{size("B", 1)}}

    K = {{size("A", 1)}}

    if M * N == 0:

        # early exit due to zero-size input(s)

        return

    stride_am = {{stride("A", 0)}}

    stride_ak = {{stride("A", 1)}}

    stride_bk = {{stride("B", 0)}}

    stride_bn = {{stride("B", 1)}}



    # based on triton.ops.matmul

    pid = tl.program_id(0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M

    grid_n = (N + BLOCK_N - 1) // BLOCK_N



    # re-order program ID for better L2 performance

    width = GROUP_M * grid_n

    group_id = pid // width

    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)

    pid_m = group_id * GROUP_M + (pid % group_size)

    pid_n = (pid % width) // (group_size)

    tl.assume(pid_m >= 0)

    tl.assume(pid_n >= 0)



    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):

        offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)

    else:

        offs_a_m = rm % M

    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):

        offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

    else:

        offs_b_n = rn % N

    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)



    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):

        {% if not EVEN_K %}

        a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K)

        b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K)

        {% endif %}

        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)

        b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)



        idx_m = offs_a_m[:, None]

        idx_n = a_k_idx_vals

        {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}}



        idx_m = b_k_idx_vals

        idx_n = offs_b_n[None, :]

        {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}

        {% if USE_FAST_ACCUM %}

        acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)

        {% else %}

        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)

        {% endif %}



    # rematerialize rm and rn to save registers

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]

    idx_n = rn[None, :]

    mask = (idx_m < M) & (idx_n < N)



    # inductor generates a suffix

    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}

"""
    ),
    cache_codegen_enabled_for_template=True,
    prologue_loads_all_inputs=True,
)

# Persistent Triton GEMM template that reads A/B tiles through TMA tensor
# descriptors.
#
# Each program starts at output tile `start_pid` and advances by NUM_SMS
# tiles. It gets `num_tiles // NUM_SMS` tiles, plus one if it falls in the
# remainder. A single flattened loop walks (tile, k-chunk) pairs: `ki` cycles
# through 0..k_tiles-1. When `ki == 0` the next tile's offsets are computed
# with GROUP_M swizzling. When `ki == k_tiles - 1` the accumulator is stored
# and reset.
#
# A_ROW_MAJOR / B_ROW_MAJOR choose the descriptor orientation; transposed
# tiles are fed to tl.dot via `.T`. TMA_EXPERIMENTAL_API selects between two
# descriptor APIs:
#   * the older device-side tensormap-create API, which writes descriptors into
#     the `ws_ptr` workspace (2 * TMA_SIZE per program: A's then B's);
#   * `tl.make_tensor_descriptor` / `tl.load_tensor_descriptor`.
persistent_tma_mm_template = TritonTemplate(
    name="mm_persistent_tma",
    grid=persistent_mm_grid,
    source=r"""

{{def_kernel("A", "B")}}

    M = {{size("A", 0)}}

    N = {{size("B", 1)}}

    K = {{size("A", 1)}}

    if M * N == 0:

        # early exit due to zero-size input(s)

        return



    start_pid = tl.program_id(0)

    grid_m = tl.cdiv(M, BLOCK_M)

    grid_n = tl.cdiv(N, BLOCK_N)

    k_tiles = tl.cdiv(K, BLOCK_K)

    num_tiles = grid_m * grid_n

    tiles_per_SM = num_tiles // NUM_SMS

    if start_pid < num_tiles % NUM_SMS:

        tiles_per_SM += 1



    tile_id = start_pid - NUM_SMS

    ki = -1



    width = GROUP_M * grid_n

    rk_for_mask = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)



    workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE

    a_desc_ptr = workspace_base

    b_desc_ptr = workspace_base + TMA_SIZE



    {%- if TMA_EXPERIMENTAL_API %}

    triton.language.extra.cuda.experimental_device_tensormap_create2d(

        desc_ptr=a_desc_ptr,

        global_address=A,

        load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],

        global_size=[M, K] if A_ROW_MAJOR else [K, M],

        element_ty=A.dtype.element_ty,

    )

    triton.language.extra.cuda.experimental_device_tensormap_create2d(

        desc_ptr=b_desc_ptr,

        global_address=B,

        load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],

        global_size=[K, N] if B_ROW_MAJOR else [N, K],

        element_ty=B.dtype.element_ty,

    )



    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)

    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)



    a_desc = a_desc_ptr

    b_desc = b_desc_ptr

    {%- else %}

    a_desc = triton.language.make_tensor_descriptor(

        base=A,

        shape=[M, K] if A_ROW_MAJOR else [K, M],

        strides=[K, 1] if A_ROW_MAJOR else [M, 1],

        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],

    )

    b_desc = triton.language.make_tensor_descriptor(

        base=B,

        shape=[K, N] if B_ROW_MAJOR else [N, K],

        strides=[N, 1] if B_ROW_MAJOR else [K, 1],

        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],

    )

    {%- endif %}



    pid_m = 0

    pid_n = 0

    rm = 0

    rn = 0



    for _ in range(0, k_tiles * tiles_per_SM):

        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)

        if ki == 0:

            tile_id += NUM_SMS

            # re-order program ID for better L2 performance

            group_id = tile_id // width

            group_size = min(grid_m - group_id * GROUP_M, GROUP_M)

            pid_m = group_id * GROUP_M + (tile_id % group_size)

            pid_n = (tile_id % width) // (group_size)



            rm = pid_m * BLOCK_M

            rn = pid_n * BLOCK_N



        rk = ki * BLOCK_K



        {%- if TMA_EXPERIMENTAL_API %}

        a = tl._experimental_descriptor_load(

            a_desc,

            [rm, rk] if A_ROW_MAJOR else [rk, rm],

            [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],

            A.dtype.element_ty,

        )

        b = tl._experimental_descriptor_load(

            b_desc,

            [rk, rn] if B_ROW_MAJOR else [rn, rk],

            [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],

            B.dtype.element_ty,

        )

        {%- else %}

        a = tl.load_tensor_descriptor(

            a_desc,

            [rm, rk] if A_ROW_MAJOR else [rk, rm],

        )

        b = tl.load_tensor_descriptor(

            b_desc,

            [rk, rn] if B_ROW_MAJOR else [rn, rk],

        )

        {%- endif %}

        acc += tl.dot(

            a if A_ROW_MAJOR else a.T,

            b if B_ROW_MAJOR else b.T,

            allow_tf32=ALLOW_TF32,

        )



        if ki == k_tiles - 1:

            # rematerialize rm and rn to save registers

            rcm = rm + tl.arange(0, BLOCK_M)

            rcn = rn + tl.arange(0, BLOCK_N)

            idx_m = rcm[:, None]

            idx_n = rcn[None, :]

            mask = (idx_m < M) & (idx_n < N)



            # inductor generates a suffix

            {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}

            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)



""",
)

# Triton helper source appended to the scaled-mm template (see
# `scaled_mm_device_tma_template`).
# - With SCALING_ROWWISE, the scale pointers are returned unchanged; the
#   per-row / per-column values are loaded later by `apply_scaling`.
# - Otherwise, one scalar is loaded from each pointer (per-tensor scaling).
load_scales = r"""

@triton.jit

def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):

    if SCALING_ROWWISE:

        # For row-wise scaling, we'll return the pointers

        return a_scale_ptr, b_scale_ptr

    else:

        # For per-tensor scaling, we'll load the scalar values

        a_scale = tl.load(a_scale_ptr)

        b_scale = tl.load(b_scale_ptr)

        return a_scale, b_scale

"""


# Triton helper source appended to the scaled-mm template. It multiplies an
# accumulator tile by the A/B scales:
# - Row-wise (SCALING_ROWWISE): loads a_scale[offs_cm * stride_a_scale_m] and
#   b_scale[offs_cn * stride_b_scale_n], masked to M / N (masked-out lanes read
#   0.0), and scales the tile by their outer product.
# - Per-tensor: scales the tile by the scalar product a_scale * b_scale.
apply_scaling = r"""

@triton.jit

def apply_scaling(

    accumulator,

    a_scale,

    b_scale,

    SCALING_ROWWISE: tl.constexpr,

    offs_cm,

    offs_cn,

    M,

    N,

    stride_a_scale_m,

    stride_b_scale_n,

):

    if SCALING_ROWWISE:

        # For row-wise scaling, we need to load the scales for each row/column

        a_scales = tl.load(

            a_scale + (offs_cm * stride_a_scale_m),

            mask=offs_cm < M,

            other=0.0,

        )

        b_scales = tl.load(

            b_scale + (offs_cn * stride_b_scale_n),

            mask=offs_cn < N,

            other=0.0,

        )

        acc_scale = a_scales[:, None] * b_scales[None, :]

    else:

        # For per-tensor scaling, we can directly use the loaded scalar values

        acc_scale = a_scale * b_scale



    return accumulator * acc_scale

"""


# Kernel body of the persistent, TMA-based scaled-mm Triton template. It is
# concatenated with `load_scales` and `apply_scaling` to form
# `scaled_mm_device_tma_template`.
#
# A is read through an (M, K) descriptor and B through an (N, K) descriptor,
# so each K step computes a @ b.T. Tiles are distributed persistently:
# program `start_pid` handles every NUM_SMS-th tile, using GROUP_M swizzling
# within a flattened (tile, k-chunk) loop. After the last K chunk of a tile,
# the accumulator is scaled (per-tensor or row-wise per SCALING_ROWWISE),
# stored, and re-zeroed.
#
# NOTE: in the TMA_EXPERIMENTAL_API branch the loads use a_desc_ptr /
# b_desc_ptr directly; a_desc / b_desc are aliases of those pointers, and each
# must alias its own operand's descriptor. The per-tile reset uses ACC_TYPE so
# the loop-carried accumulator keeps the dtype it was created with.
device_tma = r"""

{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}

    M = {{size("A", 0)}}

    N = {{size("B", 1)}}

    K = {{size("A", 1)}}

    if M * N == 0:

        # early exit due to zero-size input(s)

        return



    stride_am = {{stride("A", 0)}}

    stride_ak = {{stride("A", 1)}}

    stride_bk = {{stride("B", 0)}}

    stride_bn = {{stride("B", 1)}}



    if SCALING_ROWWISE:

        stride_a_scale_m = 1

        stride_b_scale_n = 1

    else:

        stride_a_scale_m = 0

        stride_b_scale_n = 0



    start_pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_M)

    num_pid_n = tl.cdiv(N, BLOCK_N)

    k_tiles = tl.cdiv(K, BLOCK_K)

    num_tiles = num_pid_m * num_pid_n



    workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE

    a_desc_ptr = workspace_base

    b_desc_ptr = workspace_base + TMA_SIZE



    {%- if TMA_EXPERIMENTAL_API %}

    triton.language.extra.cuda.experimental_device_tensormap_create2d(

        desc_ptr=a_desc_ptr,

        global_address=A,

        load_size=[BLOCK_M, BLOCK_K],

        global_size=[M, K],

        element_ty=A.dtype.element_ty,

    )

    triton.language.extra.cuda.experimental_device_tensormap_create2d(

        desc_ptr=b_desc_ptr,

        global_address=B,

        load_size=[BLOCK_N, BLOCK_K],

        global_size=[N, K],

        element_ty=B.dtype.element_ty,

    )



    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)

    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)



    a_desc = a_desc_ptr

    b_desc = b_desc_ptr

    {%- else %}

    a_desc = triton.language.make_tensor_descriptor(

        base=A,

        shape=[M, K],

        strides=[K, 1],

        block_shape=[BLOCK_M, BLOCK_K],

    )

    b_desc = triton.language.make_tensor_descriptor(

        base=B,

        shape=[N, K],

        strides=[K, 1],

        block_shape=[BLOCK_N, BLOCK_K],

    )

    {%- endif %}



    tiles_per_SM = num_tiles // NUM_SMS

    if start_pid < num_tiles % NUM_SMS:

        tiles_per_SM += 1



    tile_id = start_pid - NUM_SMS

    ki = -1



    pid_m = 0

    pid_n = 0

    offs_am = 0

    offs_bn = 0



    num_pid_in_group = GROUP_M * num_pid_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)



    for _ in range(0, k_tiles * tiles_per_SM):

        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)

        if ki == 0:

            tile_id += NUM_SMS

            group_id = tile_id // num_pid_in_group

            first_pid_m = group_id * GROUP_M

            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)

            pid_m = first_pid_m + (tile_id % group_size_m)

            pid_n = (tile_id % num_pid_in_group) // group_size_m



            offs_am = pid_m * BLOCK_M

            offs_bn = pid_n * BLOCK_N



        offs_k = ki * BLOCK_K



        {%- if TMA_EXPERIMENTAL_API %}

        a = tl._experimental_descriptor_load(

            a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K],  A.dtype.element_ty

        )

        b = tl._experimental_descriptor_load(

            b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K],  B.dtype.element_ty

        )

        {%- else %}

        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])

        b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])

        {%- endif %}

        if USE_FAST_ACCUM:

            accumulator = tl.dot(a, b.T, accumulator)

        else:

            accumulator += tl.dot(a, b.T)



        if ki == k_tiles - 1:

            # Apply inverse scaling

            offs_cm = offs_am + tl.arange(0, BLOCK_M)

            offs_cn = offs_bn + tl.arange(0, BLOCK_N)

            # Apply scaling

            accumulator = apply_scaling(

                accumulator,

                a_scale,

                b_scale,

                SCALING_ROWWISE,

                offs_cm,

                offs_cn,

                M,

                N,

                stride_a_scale_m,

                stride_b_scale_n,

            )



            idx_m = offs_cm[:, None]

            idx_n = offs_cn[None, :]

            mask = (idx_m < M) & (idx_n < N)

            # inductor generates a suffix

            {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}

            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

"""


# Persistent TMA scaled-mm template: the kernel body followed by the two
# @triton.jit helpers it calls (load_scales, then apply_scaling).
scaled_mm_device_tma_template = TritonTemplate(
    source="".join((device_tma, load_scales, apply_scaling)),
    grid=persistent_mm_grid,
    name="scaled_mm_device_tma",
)


# Memoized so that each extern function is wrapped (and thus registered) only
# once, however many times a lowering asks for it.
@functools.cache
def lazy_register_extern_choice(fn):
    """Return the cached ``ExternKernelChoice`` for ``fn``, creating it on first use."""
    choice = ExternKernelChoice(fn)
    return choice


# ATen extern kernel choices that autotuning can select alongside generated
# templates. Each pairs a Python callable with an ATen C++ kernel name string.
aten_mm = ExternKernelChoice(
    torch.mm,
    "at::mm_out",
)

aten_addmm = ExternKernelChoice(
    torch.addmm,
    "at::addmm_out",
    op_overload=aten.addmm.default,
)

aten__int_mm = ExternKernelChoice(
    torch._int_mm,
    "at::_int_mm_out",
)

# Note the kernel name has no `_out` suffix and has_out_variant is False.
aten__sparse_semi_structured_mm = ExternKernelChoice(
    torch._sparse_semi_structured_mm,
    "at::_sparse_semi_structured_mm",
    has_out_variant=False,
)

aten__fp8_mm = ExternKernelChoice(
    torch._scaled_mm,
    "at::_scaled_mm_out",
    op_overload=aten._scaled_mm.out,
)


def _is_int8_mat(mat):
    return mat.get_dtype() in (torch.int8, torch.uint8)


def _is_large_block_for_cpu(m, n, k):
    # Thresholds are experimentally determined to reduce Triton CPU compile times
    return m * n > 2**13


@functools.lru_cache
def using_b200() -> bool:
    """Report whether the current CUDA device is B200-class.

    Detection is by compute-capability major version 10 (10.0 / 10.0a).
    Returns False when CUDA is unavailable. The result is cached for the
    process lifetime.
    """
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.major == 10
    return False


def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    ``torch.addmm`` wrapper that passes a 1-D bias when ``inp`` is a broadcast row.

    When ``inp`` has a zero stride along dim 0 (an expanded row) or only one row,
    its first row is passed to ``torch.addmm`` as a 1-D bias. That routes to a
    different, usually faster, cublasLt kernel; a few shapes are slower this way,
    but they are uncommon. Otherwise ``inp`` is passed through unchanged.
    """
    is_broadcast_row = inp.stride(0) == 0 or inp.size(0) == 1
    bias = inp[0] if is_broadcast_row else inp
    return torch.addmm(bias, mat1, mat2, out=out, alpha=alpha, beta=beta)


def check_supported_striding(mat_a, mat_b) -> None:
    """Require ``mat_a`` to be row-major and ``mat_b`` column-major.

    A matrix whose first or second dimension is statically known to be 0 is
    accepted regardless of its strides. Failures are reported via
    ``torch._check`` with a message that includes the offending stride.
    """

    def unit_stride_at(stride, dim) -> bool:
        # True only when the stride along ``dim`` is provably 1.
        return V.graph.sizevars.statically_known_equals(stride[dim], 1)

    def is_empty(size) -> bool:
        # True when either of the two leading dims is provably 0.
        sizevars = V.graph.sizevars
        if sizevars.statically_known_equals(size[0], 0):
            return True
        return bool(sizevars.statically_known_equals(size[1], 0))

    # mat_a (self): unit stride along dim 1, i.e. row-major.
    a_ok = unit_stride_at(mat_a.get_stride(), 1) or is_empty(mat_a.get_size())
    torch._check(
        a_ok,
        lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
    )

    # mat_b: unit stride along dim 0, i.e. column-major.
    b_ok = unit_stride_at(mat_b.get_stride(), 0) or is_empty(mat_b.get_size())
    torch._check(
        b_ok,
        lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
    )


# Extern choice wrapping the Python-level ``bias_addmm`` defined above. The
# second argument is ``None`` where the other choices pass an ATen kernel
# name string; presumably this is because ``bias_addmm`` has no single ATen
# counterpart (TODO confirm).
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)


def decomposeK(a, b, k_splits):
    """Compute ``a @ b`` by splitting the K dimension into ``k_splits`` chunks.

    ``a`` is (m, k) and ``b`` is (k, n). K is cut into ``k_splits`` pieces of
    ``k // k_splits`` elements. The pieces are multiplied as a batched matmul
    with float32 output, summed over the batch dimension, and cast back to
    ``a``'s dtype. The reshapes assume ``k`` is divisible by ``k_splits``.
    """
    m, k = a.shape[0], a.shape[1]
    n = b.shape[1]
    chunk = k // k_splits

    # (m, k) -> (m, k_splits, chunk) -> (k_splits, m, chunk)
    lhs = a.reshape(m, k_splits, chunk).permute(1, 0, 2)
    # (k, n) -> (k_splits, chunk, n)
    rhs = b.reshape(k_splits, chunk, n)

    partial_products = torch.bmm(lhs, rhs, out_dtype=torch.float32)
    return partial_products.sum(0).to(a.dtype)


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)

    Collects candidate kernels for ``mat1 @ mat2`` and lets
    ``autotune_select_algorithm`` choose among them:

    * the aten extern kernel (when ``use_aten_gemm_kernels()``). Its layout is
      relaxed to a ``FlexibleLayout`` unless max-autotune is enabled;
    * Triton ``mm_template`` configs, persistent TMA configs when
      ``use_triton_tma_template`` allows, and decompose-K subgraph choices
      when ``use_decompose_k_choice`` allows and no operand size/stride
      contains unbacked symbols (Triton paths only for non-empty problems);
    * CUTLASS, CK, CK-tile and C++ GEMM templates when their gates allow;
    * extra Triton configs filtered/ranked by AutoHeuristic when
      ``run_autoheuristic("mm")`` is on;
    * every kernel listed in ``inductor_config.external_matmul``.
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    device_type = ir.get_device_type(mat1)
    name = "mm"

    # below is for getting an overview logging info of inductor mms
    counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1
    log.info(
        "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
        m,
        n,
        k,
        mat1.get_dtype(),
        mat2.get_dtype(),
        layout,
    )

    aten_layout = layout
    if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm):
        aten_layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    _, is_nonzero = _is_static_problem(layout)

    mm_configs = V.choices.get_base_mm_configs(device_type)
    persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
    extra_mm_configs = V.choices.get_extra_mm_configs(device_type)

    dtype = mat1.get_dtype()
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(
            m,
            n,
            k,
            **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
        ):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        if use_triton_tma_template(mat1, mat2):
            for config in persistent_mm_configs(
                m,
                n,
                k,
                **mm_config_kwargs(
                    device_type, _is_large_block_for_cpu, dtype.itemsize
                ),
            ):
                persistent_tma_mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                    workspace_arg=get_tma_workspace_arg(
                        num_tma_descriptors=2,
                        device=mat1.get_device(),
                    ),
                    **mm_options(config, m, n, k, layout),
                    **persistent_mm_options(mat1, mat2),
                )

        from torch._inductor.ir import get_free_symbols

        # Only do split-k optimization if K is much larger than m, n and m, n are small
        # and if there aren't any unbacked symbols
        unbacked_symbols = any(
            len(get_free_symbols(itr, unbacked_only=True)) > 0
            for itr in (
                mat1.get_size(),
                mat1.get_stride(),
                mat2.get_size(),
                mat2.get_stride(),
            )
        )
        if use_decompose_k_choice(m, n, k) and not unbacked_symbols:
            from torch._dispatch.python import enable_python_dispatcher

            from ..decomposition import select_decomp_table

            k_splits = get_k_splits(m, n, k)
            for k_split in k_splits:
                # decomposeK reshapes K into k_split equal parts, so only
                # splits that provably divide K are usable.
                if not V.graph.sizevars.statically_known_true(
                    sympy.Eq(sympy.Mod(k, k_split), 0)
                ):
                    continue

                with enable_python_dispatcher():
                    decompositions = select_decomp_table()

                    decompose_k_subgraph_template = SubgraphTemplate(
                        name=f"decompose_k_mm_{k_split}_split",
                        make_fx_graph=make_fx(
                            functools.partial(decomposeK, k_splits=k_split),
                            decompositions,
                        ),
                    )

                decompose_k_subgraph_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                )

    if (
        is_nonzero
        and use_cutlass_template(layout, m, n, k)
        and _use_cutlass_for_op("mm")
    ):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if is_nonzero and use_ck_gemm_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
    if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k):
        CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])

    if use_cpp_gemm_template(layout, mat1, mat2):
        CppGemmTemplate.add_choices(
            choices,
            layout,
            [mat1, mat2],
        )

    input_nodes = [mat1, mat2]
    if (
        is_nonzero
        and use_triton_template(layout)
        and torch._inductor.config.run_autoheuristic(name)
        and is_triton(mat1)
    ):
        always_included = []
        if use_aten_gemm_kernels():
            always_included.append("extern_mm")
        num_choices_before_extra_configs = len(choices)
        for config in extra_mm_configs(
            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
        ):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        # using AutoHeuristic for ranking
        ah_choices = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mm_operations(),
            None,
            top_k=10,
            always_included=always_included,
        )
        if not torch._inductor.config.collect_autoheuristic(name):
            # if we are collecting data, we do not want to modify choices
            if ah_choices is not None and len(ah_choices) > 0:
                # the order in which autoheuristic returns choices is not the same as
                # the order of choices, which affects things like epilogue fusion.
                # once epilogue fusion benchmarks choices in sorted order, I think we can
                # just use the order returned by autoheuristic
                choices = [choice for choice in choices if choice in ah_choices]
            else:
                # No ranking available: drop the extra configs added above.
                choices = choices[:num_choices_before_extra_configs]

    # Use a dedicated loop name so the problem dimension ``k`` is not clobbered.
    for extern_mm in inductor_config.external_matmul:
        choices.append(
            lazy_register_extern_choice(extern_mm).bind((mat1, mat2), layout)
        )

    return autotune_select_algorithm(name, choices, [mat1, mat2], layout)


@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    """
    Lowering for autotuning aten._int_mm.

    The output layout is built by ``mm_args`` with ``out_dtype=torch.int32``.
    The candidates are:

    * the aten extern kernel;
    * CUTLASS 3.x GEMMs (fuseable and non-fuseable variants), only for
      static, non-empty problems;
    * Triton ``mm_template`` configs from ``V.choices.get_int8_mm_configs``.
    """
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )

    # below is for getting an overview logging info of inductor mms
    counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
    log.info(
        "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
        m,
        n,
        k,
        mat1.get_dtype(),
        mat2.get_dtype(),
        layout,
    )

    device_type = ir.get_device_type(mat1)

    static_shape, is_nonzero = _is_static_problem(layout)
    # CUTLASS is only considered for static, non-empty problems.
    cutlass_eligible = (
        static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
    )

    choices = []
    if use_aten_gemm_kernels():
        choices.append(aten__int_mm.bind((mat1, mat2), layout))

    if cutlass_eligible and _use_cutlass_for_op("int_mm"):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )

    int8_config_gen = V.choices.get_int8_mm_configs(device_type)

    triton_enabled = is_nonzero and use_triton_template(layout, enable_int32=True)
    if triton_enabled:
        config_kwargs = mm_config_kwargs(device_type, _is_large_block_for_cpu)
        for cfg in int8_config_gen(m, n, k, **config_kwargs):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(cfg, m, n, k, layout),
            )

    return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)


@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """
    Lowering for autotuning aten.addmm (``beta * inp + alpha * (mat1 @ mat2)``).

    Non-autotune path: used for an empty problem, or when max-autotune is off.
    Only the aten kernel is offered, applied to the original ``inp``. A
    ``FixedLayout`` is relaxed to a ``FlexibleLayout`` so output strides can
    be padded.

    Autotune path: candidates are built from the bias expanded by
    ``mm_args``:

    * the aten kernel;
    * the ``bias_addmm`` extern choice, for a zero-stride CUDA bias when
      ``triton.autotune_cublasLt`` is on;
    * Triton ``mm_template`` and persistent TMA configs, both with the addmm
      epilogue;
    * CUTLASS, CK and C++ GEMM templates.
    """
    device_type = ir.get_device_type(mat1)
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    _, is_nonzero = _is_static_problem(layout)

    # below is for getting an overview logging info of inductor mms
    counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1
    log.info(
        "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
        m,
        n,
        k,
        mat1.get_dtype(),
        mat2.get_dtype(),
        layout,
    )

    if (not is_nonzero) or (
        not (inductor_config.max_autotune or inductor_config.max_autotune_gemm)
    ):
        # Use a FlexibleLayout if we are not autotuning.
        # This allows padding strides for the output.
        from torch._inductor.ir import FixedLayout, FlexibleLayout

        if isinstance(layout, FixedLayout):
            layout = FlexibleLayout(
                device=layout.device, dtype=layout.dtype, size=layout.size
            )
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    mm_configs = V.choices.get_base_mm_configs(device_type)
    persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)

    dtype = mat1.get_dtype()
    if is_nonzero and use_triton_template(layout):
        # Both Triton templates below use the same addmm epilogue, so both
        # must report the same epilogue identity.
        addmm_epilogue_hash = str(["addmm_epilogue", layout.dtype, alpha, beta])

        for config in mm_configs(
            m,
            n,
            k,
            **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
        ):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                epilogue_fn_hash=addmm_epilogue_hash,
            )

        if use_triton_tma_template(mat1, mat2):
            for config in persistent_mm_configs(
                m,
                n,
                k,
                **mm_config_kwargs(
                    device_type, _is_large_block_for_cpu, dtype.itemsize
                ),
            ):
                persistent_tma_mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(inp_expanded, mat1, mat2),
                    layout=layout,
                    workspace_arg=get_tma_workspace_arg(
                        num_tma_descriptors=2,
                        device=mat1.get_device(),
                    ),
                    **mm_options(config, m, n, k, layout),
                    **persistent_mm_options(mat1, mat2),
                    prefix_args=1,
                    epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                    epilogue_fn_hash=addmm_epilogue_hash,
                )

    if (
        is_nonzero
        and use_cutlass_template(layout, m, n, k)
        and _use_cutlass_for_op("addmm")
    ):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
            input_reorder=[2, 0, 1],
        )

    if is_nonzero and use_ck_gemm_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
            input_reorder=[2, 0, 1],
        )

    if use_cpp_gemm_template(layout, mat1, mat2):
        CppGemmTemplate.add_choices(
            choices,
            layout,
            [inp_expanded, mat1, mat2],
            alpha=alpha,
            beta=beta,
            has_bias=True,
        )

    return autotune_select_algorithm(
        "addmm", choices, [inp_expanded, mat1, mat2], layout
    )


@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
def tuned_sparse_semi_structured_mm(
    mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
):
    """
    Lowering for autotuning aten._sparse_semi_structured_mm.

    The sizes are guarded so that:

    * ``mat1`` and ``mat1_meta`` have the same number of rows (M);
    * ``2 * mat1.size(1)`` equals ``mat2.size(0)`` (K).

    When ``layout`` is omitted, a row-major (M, N) ``FixedLayout`` on
    ``mat2``'s device is created, using ``out_dtype`` or else ``mat2``'s
    dtype. Supplying both ``layout`` and ``out_dtype`` trips an assertion.

    Candidates are the aten kernel and, for non-empty problems, CUTLASS 2.x
    GEMMs.
    """
    from torch._inductor.select_algorithm import realize_inputs

    mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
    rows_a, half_k = mat1.get_size()
    rows_meta, _ = mat1_meta.get_size()
    full_k, n = mat2.get_size()
    sizevars = V.graph.sizevars
    m = sizevars.guard_equals(rows_a, rows_meta)
    k = sizevars.guard_equals(2 * half_k, full_k)

    if layout is not None:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    else:
        from torch._inductor.ir import FixedLayout

        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [m, n],
            [n, 1],
        )

    choices = []
    if use_aten_gemm_kernels():
        choices.append(
            aten__sparse_semi_structured_mm.bind(
                (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
            )
        )

    cutlass_requested = (
        m * n != 0
        and use_cutlass_template(layout, m, n, k)
        and _use_cutlass_for_op("sparse_semi_structured_mm")
    )
    if cutlass_requested:
        # Note the CUTLASS input order: (mat1, mat2, mat1_meta).
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
        )

    return autotune_select_algorithm(
        "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
    )


# Register ``constrain_to_fx_strides`` as the layout constraint applied to the
# inputs of aten._scaled_mm.default before the lowering below runs.
# NOTE(review): by its name this presumably pins inputs to the strides recorded
# in the FX graph, which check_supported_striding then validates — confirm.
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)


@register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
def tuned_scaled_mm(
    mat_a,
    mat_b,
    scale_a,
    scale_b,
    bias=None,
    scale_result=None,
    out_dtype=None,
    use_fast_accum=False,
    layout=None,
):
    """
    Performs an optimized matrix multiplication where scaling factors are applied
    to the inputs and/or output.

    Args:
        mat_a (Tensor): First input matrix; must be row-major (or have a
            statically zero dimension).
        mat_b (Tensor): Second input matrix; must be column-major (or have a
            statically zero dimension).
        scale_a (Tensor): Scale factor applied to mat_a (supports broadcasting).
            Triton choices are only generated when its dtype is float32.
        scale_b (Tensor): Scale factor applied to mat_b (supports broadcasting).
        bias (Tensor, optional): Optional bias tensor to add to the result.
        scale_result (Tensor, optional): Accepted for signature compatibility
            with aten._scaled_mm; it is not used by this lowering.
        out_dtype (torch.dtype, optional): Dtype of the result.
        use_fast_accum (bool): Forwarded to the aten, Triton and CUTLASS choices.
        layout: Layout hint for optimization.

    Returns:
        Tensor: The result of the scaled matrix multiplication.
    """
    m, n, k, layout, mat_a, mat_b = mm_args(
        mat_a, mat_b, layout=layout, out_dtype=out_dtype
    )
    # below is for getting an overview logging info of inductor mms
    counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
    log.info(
        "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
        m,
        n,
        k,
        mat_a.get_dtype(),
        mat_b.get_dtype(),
        layout,
    )

    device_type = ir.get_device_type(mat_a)
    check_supported_striding(mat_a, mat_b)

    scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b)

    input_nodes: tuple[Any, ...]

    # Test for presence explicitly rather than relying on the truthiness of
    # an IR node.
    if bias is None:
        input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real)
    else:
        bias_real = realize_inputs(bias)
        input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real)

    aten_choice = aten__fp8_mm.bind(
        input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
    )

    choices = []
    if use_aten_gemm_kernels():
        choices.append(aten_choice)

    # We don't have triton lowerings for the MX variants yet
    if scale_a.dtype != torch.float32:
        return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)

    _, is_nonzero = _is_static_problem(layout)

    scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
    scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
        device_type
    )

    if is_nonzero and use_triton_template(layout, enable_float8=True):
        triton_input_nodes: tuple[Any, ...]
        if bias is not None and len(mat_b.get_size()) == len(bias.get_size()) + 1:
            # Need to unsqueeze bias from [N] -> [1, N]
            triton_bias = L[aten.unsqueeze](bias, 0)
        else:
            triton_bias = bias

        if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
            assert len(scale_a.get_size()) == len(scale_b.get_size())
            # Need to unsqueeze scale from [] -> [1, 1]
            triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
            triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
        else:
            triton_scale_a = scale_a
            triton_scale_b = scale_b

        if bias is not None:
            triton_input_nodes = (
                mat_a,
                mat_b,
                triton_scale_a,
                triton_scale_b,
                triton_bias,
            )
            suffix_args = 3
        else:
            triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b)
            suffix_args = 2

        # TODO (paulzhan): There is no template that exists for bias and TMA
        # Don't run tma template currently if bias exists
        if use_triton_tma_template(mat_a, mat_b) and bias is None:
            for config in scaled_persistent_mm_configs(m, n, k):
                kwargs = scaled_mm_options(
                    config,
                    m,
                    n,
                    k,
                    layout,
                    scale_a,
                    scale_b,
                    use_fast_accum,
                    device_tma=True,
                )
                scaled_mm_device_tma_template.maybe_append_choice(
                    choices,
                    input_nodes=triton_input_nodes,
                    layout=layout,
                    workspace_arg=get_tma_workspace_arg(
                        num_tma_descriptors=2,
                        device=mat_a.get_device(),
                    ),
                    **kwargs,
                )

        for config in scaled_mm_configs(m, n, k):
            if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)):
                # Triton crashes however uncommon for real workloads
                continue

            # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
            # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
            if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)):
                continue

            kwargs = scaled_mm_options(
                config, m, n, k, layout, scale_a, scale_b, use_fast_accum
            )
            # possibly appends a TritonTemplateCaller to choices
            mm_template.maybe_append_choice(
                choices,
                input_nodes=triton_input_nodes,
                layout=layout,
                **kwargs,
                suffix_args=suffix_args,
                epilogue_fn=scale_mm_epilogue(),
                epilogue_fn_hash="scale_mm_epilogue",
            )

    if (
        is_nonzero
        and use_cutlass_template(layout, m, n, k)
        and _use_cutlass_for_op("scaled_mm")
    ):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices,
            layout,
            input_nodes,  # type: ignore[arg-type]
            use_fast_accum=use_fast_accum,  # type: ignore[arg-type]
        )

    if is_nonzero and use_ck_gemm_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)

    return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)


@functools.cache
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    props = torch.cuda.get_device_properties(index or 0)
    return props.major <= 7


def dims_are_int(dims):
    """Return True if every entry of ``dims`` is a Python ``int`` (True when empty)."""
    for dim in dims:
        if not isinstance(dim, int):
            return False
    return True


def mm_autoheuristic(
    mat1,
    mat2,
    m,
    n,
    k,
    choices,
    name,
    input_nodes,
    ops,
    precondition,
    top_k: Optional[int] = None,
    always_included=None,
):
    """
    Rank or pick among mm ``choices`` using AutoHeuristic.

    Symbolic m, n, k (and strides) are replaced by size hints. If m, n or k
    is still not a Python int afterwards, returns ``None``.

    Otherwise builds an ``AHContext`` with these features, in order:
    m/k/n, both dtypes, both strides, both contiguity flags, and the TF32
    setting when ``name == "mm"``.

    Returns:
        ``get_top_k_choices_caller(top_k, always_included=...)`` when
        ``top_k`` is given, else ``get_choice_caller()``. The fallback handed
        to AutoHeuristic always returns ``None``.
    """
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None
    mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)

    # Feature order is kept stable: m, k, n, dtypes, strides, contiguity.
    context = AHContext()
    context.add_feature("m", m)
    context.add_feature("k", k)
    context.add_feature("n", n)
    context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
    context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
    context_add_strides(context, "mat1", mat1_stride)
    context_add_strides(context, "mat2", mat2_stride)
    context.add_feature(
        "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
    )
    context.add_feature(
        "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
    )
    if name == "mm":
        context_add_using_tf32(context, mat1.layout.dtype)

    def _no_choice():
        return None

    autoheuristic = AutoHeuristicSelectAlgorithm(
        fallback=_no_choice,
        choices=choices,
        input_nodes=input_nodes,
        context=context,
        name=name,
        augment_context=ops,
        precondition=precondition,
    )

    if top_k is None:
        return autoheuristic.get_choice_caller()

    # TODO: is there a cleaner way to ensure aten.mm is always included?
    return autoheuristic.get_top_k_choices_caller(
        top_k, always_included=always_included
    )


def get_size_hints(mat1, mat2, m, n, k):
    """
    Return ``(m, n, k)`` with non-int sizes replaced by size hints.

    Resolution happens in two steps:

    * if m or k is not an int, both are re-read from ``mat1.get_size()``;
    * then, if n or k is not an int, both are re-read from
      ``mat2.get_size()``. When this second lookup runs it also overwrites k.

    Unbacked symbols resolve to
    ``torch._inductor.config.unbacked_symint_fallback``.
    """

    def _hints_for(mat):
        return V.graph.sizevars.size_hints(
            mat.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    if not (isinstance(m, int) and isinstance(k, int)):
        m, k = _hints_for(mat1)

    if not (isinstance(n, int) and isinstance(k, int)):
        k, n = _hints_for(mat2)

    return m, n, k


def get_size_hints_strides(mat1, mat2):
    """
    Return the strides of ``mat1`` and ``mat2`` as tuples of size hints.

    A stride made only of Python ints is returned directly as a tuple.
    Otherwise it is resolved with ``V.graph.sizevars.size_hints``, using
    ``torch._inductor.config.unbacked_symint_fallback`` for unbacked symbols.

    Returns:
        ``(mat1_stride_hints, mat2_stride_hints)``.
    """

    def _stride_hints(stride):
        # A stride is a sequence, so test its entries: ``isinstance(stride, int)``
        # on the sequence itself can never be true.
        if all(isinstance(s, int) for s in stride):
            return tuple(stride)
        return V.graph.sizevars.size_hints(
            stride,
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    return _stride_hints(mat1.layout.stride), _stride_hints(mat2.layout.stride)