lsmpp commited on
Commit
8934dae
·
verified ·
1 Parent(s): 613726b

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +177 -0
  2. .venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc +0 -0
  3. .venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc +0 -0
  4. .venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc +0 -0
  5. .venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc +0 -0
  6. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc +0 -0
  7. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc +0 -0
  8. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc +0 -0
  9. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc +0 -0
  10. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc +0 -0
  11. .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc +0 -0
  12. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py +0 -0
  13. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
  14. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
  15. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py +150 -0
  16. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
  17. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
  18. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py +0 -0
  19. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py +315 -0
  20. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py +339 -0
  21. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py +119 -0
  22. .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py +95 -0
  23. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py +0 -0
  24. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py +31 -0
  25. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +443 -0
  26. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py +175 -0
  27. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py +2691 -0
  28. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py +0 -0
  29. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py +262 -0
  30. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py +1081 -0
  31. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1777 -0
  32. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py +500 -0
  33. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py +2011 -0
  34. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py +138 -0
  35. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py +597 -0
  36. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py +776 -0
  37. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +0 -0
  38. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +878 -0
  39. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py +717 -0
  40. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py +99 -0
  41. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py +27 -0
  42. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
  43. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +293 -0
  44. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +45 -0
  45. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +674 -0
  46. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +318 -0
  47. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py +105 -0
  48. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
  49. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +240 -0
  50. .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +411 -0
.gitattributes CHANGED
@@ -1825,3 +1825,180 @@ illustrious_generated/e3bdc40b2136.png filter=lfs diff=lfs merge=lfs -text
1825
  illustrious_generated/6ea1f2330bb2.png filter=lfs diff=lfs merge=lfs -text
1826
  illustrious_generated/a08733bfce6d.png filter=lfs diff=lfs merge=lfs -text
1827
  illustrious_generated/0aec7eb07d4f.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1825
  illustrious_generated/6ea1f2330bb2.png filter=lfs diff=lfs merge=lfs -text
1826
  illustrious_generated/a08733bfce6d.png filter=lfs diff=lfs merge=lfs -text
1827
  illustrious_generated/0aec7eb07d4f.png filter=lfs diff=lfs merge=lfs -text
1828
+ illustrious_generated/699b6be34d1c.png filter=lfs diff=lfs merge=lfs -text
1829
+ illustrious_generated/b8fe03330562.png filter=lfs diff=lfs merge=lfs -text
1830
+ illustrious_generated/efb6fe0239e2.png filter=lfs diff=lfs merge=lfs -text
1831
+ illustrious_generated/a00e6e3ca42d.png filter=lfs diff=lfs merge=lfs -text
1832
+ illustrious_generated/4c39593e352f.png filter=lfs diff=lfs merge=lfs -text
1833
+ illustrious_generated/daf79f3a5a47.png filter=lfs diff=lfs merge=lfs -text
1834
+ illustrious_generated/b66afd053661.png filter=lfs diff=lfs merge=lfs -text
1835
+ illustrious_generated/8a594acaada9.png filter=lfs diff=lfs merge=lfs -text
1836
+ illustrious_generated/532b5ad4f3b4.png filter=lfs diff=lfs merge=lfs -text
1837
+ illustrious_generated/aa2403d0dbde.png filter=lfs diff=lfs merge=lfs -text
1838
+ illustrious_generated/ffa981380c46.png filter=lfs diff=lfs merge=lfs -text
1839
+ illustrious_generated/60f785243e60.png filter=lfs diff=lfs merge=lfs -text
1840
+ illustrious_generated/776bf5e2d2a1.png filter=lfs diff=lfs merge=lfs -text
1841
+ illustrious_generated/5ab3056565d7.png filter=lfs diff=lfs merge=lfs -text
1842
+ illustrious_generated/bde193a32ad5.png filter=lfs diff=lfs merge=lfs -text
1843
+ illustrious_generated/58b9009a8433.png filter=lfs diff=lfs merge=lfs -text
1844
+ illustrious_generated/db68dd8cb2cf.png filter=lfs diff=lfs merge=lfs -text
1845
+ illustrious_generated/79a4d58eda5c.png filter=lfs diff=lfs merge=lfs -text
1846
+ illustrious_generated/6c2bec98397a.png filter=lfs diff=lfs merge=lfs -text
1847
+ illustrious_generated/56b23fd84d9e.png filter=lfs diff=lfs merge=lfs -text
1848
+ illustrious_generated/e38a1e2e958f.png filter=lfs diff=lfs merge=lfs -text
1849
+ illustrious_generated/4014bca6d153.png filter=lfs diff=lfs merge=lfs -text
1850
+ illustrious_generated/754db85fe85f.png filter=lfs diff=lfs merge=lfs -text
1851
+ illustrious_generated/0a4d2753b7c7.png filter=lfs diff=lfs merge=lfs -text
1852
+ illustrious_generated/a0e55ddbf74c.png filter=lfs diff=lfs merge=lfs -text
1853
+ illustrious_generated/c472bc9db7e5.png filter=lfs diff=lfs merge=lfs -text
1854
+ illustrious_generated/07f25b42b8d7.png filter=lfs diff=lfs merge=lfs -text
1855
+ illustrious_generated/a317dff0fe2e.png filter=lfs diff=lfs merge=lfs -text
1856
+ illustrious_generated/1ca2c247cfef.png filter=lfs diff=lfs merge=lfs -text
1857
+ illustrious_generated/92e013a48101.png filter=lfs diff=lfs merge=lfs -text
1858
+ illustrious_generated/53d7fbefadd0.png filter=lfs diff=lfs merge=lfs -text
1859
+ illustrious_generated/0d1561551252.png filter=lfs diff=lfs merge=lfs -text
1860
+ illustrious_generated/13d5cc66d420.png filter=lfs diff=lfs merge=lfs -text
1861
+ illustrious_generated/7429db9846f3.png filter=lfs diff=lfs merge=lfs -text
1862
+ illustrious_generated/49222c07a03f.png filter=lfs diff=lfs merge=lfs -text
1863
+ illustrious_generated/e526540fdbb8.png filter=lfs diff=lfs merge=lfs -text
1864
+ illustrious_generated/0e42a3a7d296.png filter=lfs diff=lfs merge=lfs -text
1865
+ illustrious_generated/29259007fbbf.png filter=lfs diff=lfs merge=lfs -text
1866
+ illustrious_generated/e189b3abf6d2.png filter=lfs diff=lfs merge=lfs -text
1867
+ illustrious_generated/4af1ebb4488c.png filter=lfs diff=lfs merge=lfs -text
1868
+ illustrious_generated/1755674c7bc1.png filter=lfs diff=lfs merge=lfs -text
1869
+ illustrious_generated/9a7e94351f39.png filter=lfs diff=lfs merge=lfs -text
1870
+ illustrious_generated/548d5f09851f.png filter=lfs diff=lfs merge=lfs -text
1871
+ illustrious_generated/8d3200b9e787.png filter=lfs diff=lfs merge=lfs -text
1872
+ illustrious_generated/a3a89fac4fcf.png filter=lfs diff=lfs merge=lfs -text
1873
+ illustrious_generated/e4bf1a6be7c6.png filter=lfs diff=lfs merge=lfs -text
1874
+ illustrious_generated/37c414606b7d.png filter=lfs diff=lfs merge=lfs -text
1875
+ illustrious_generated/c4033b9e6115.png filter=lfs diff=lfs merge=lfs -text
1876
+ illustrious_generated/29c45b6176ff.png filter=lfs diff=lfs merge=lfs -text
1877
+ illustrious_generated/e729db57f47c.png filter=lfs diff=lfs merge=lfs -text
1878
+ illustrious_generated/96a62fa83b20.png filter=lfs diff=lfs merge=lfs -text
1879
+ illustrious_generated/6a51d82ae8bc.png filter=lfs diff=lfs merge=lfs -text
1880
+ illustrious_generated/4632a958366a.png filter=lfs diff=lfs merge=lfs -text
1881
+ illustrious_generated/47aa86a18b5a.png filter=lfs diff=lfs merge=lfs -text
1882
+ illustrious_generated/d1995ff196e3.png filter=lfs diff=lfs merge=lfs -text
1883
+ illustrious_generated/e5ef86680557.png filter=lfs diff=lfs merge=lfs -text
1884
+ illustrious_generated/bd545030c487.png filter=lfs diff=lfs merge=lfs -text
1885
+ illustrious_generated/cef6d8bbd520.png filter=lfs diff=lfs merge=lfs -text
1886
+ illustrious_generated/6b23dba7be7d.png filter=lfs diff=lfs merge=lfs -text
1887
+ illustrious_generated/8d3ab41288cb.png filter=lfs diff=lfs merge=lfs -text
1888
+ illustrious_generated/e31935899f08.png filter=lfs diff=lfs merge=lfs -text
1889
+ illustrious_generated/bc6707dcf1f2.png filter=lfs diff=lfs merge=lfs -text
1890
+ illustrious_generated/b3a626cfe029.png filter=lfs diff=lfs merge=lfs -text
1891
+ illustrious_generated/3503bc2ba8ab.png filter=lfs diff=lfs merge=lfs -text
1892
+ illustrious_generated/ce38aa6d180d.png filter=lfs diff=lfs merge=lfs -text
1893
+ illustrious_generated/6035404a25b1.png filter=lfs diff=lfs merge=lfs -text
1894
+ illustrious_generated/9b1bb21a46e8.png filter=lfs diff=lfs merge=lfs -text
1895
+ illustrious_generated/a4eb52675c75.png filter=lfs diff=lfs merge=lfs -text
1896
+ illustrious_generated/a7acf95bd41f.png filter=lfs diff=lfs merge=lfs -text
1897
+ illustrious_generated/c636dce88f49.png filter=lfs diff=lfs merge=lfs -text
1898
+ illustrious_generated/8ef3560e5a42.png filter=lfs diff=lfs merge=lfs -text
1899
+ illustrious_generated/4f7953e3c61d.png filter=lfs diff=lfs merge=lfs -text
1900
+ illustrious_generated/a9d16660f9be.png filter=lfs diff=lfs merge=lfs -text
1901
+ illustrious_generated/5028ea7283a1.png filter=lfs diff=lfs merge=lfs -text
1902
+ illustrious_generated/259e79d12cba.png filter=lfs diff=lfs merge=lfs -text
1903
+ illustrious_generated/aa08fc4ab5f7.png filter=lfs diff=lfs merge=lfs -text
1904
+ illustrious_generated/a048f190d616.png filter=lfs diff=lfs merge=lfs -text
1905
+ illustrious_generated/64d5e838ff32.png filter=lfs diff=lfs merge=lfs -text
1906
+ illustrious_generated/f8b258d2d426.png filter=lfs diff=lfs merge=lfs -text
1907
+ illustrious_generated/73e5f02dca91.png filter=lfs diff=lfs merge=lfs -text
1908
+ illustrious_generated/176f37a6cd7f.png filter=lfs diff=lfs merge=lfs -text
1909
+ illustrious_generated/75bc8bf6da61.png filter=lfs diff=lfs merge=lfs -text
1910
+ illustrious_generated/831ac5f05b4e.png filter=lfs diff=lfs merge=lfs -text
1911
+ illustrious_generated/6061d4cce84a.png filter=lfs diff=lfs merge=lfs -text
1912
+ illustrious_generated/1970ccc17731.png filter=lfs diff=lfs merge=lfs -text
1913
+ illustrious_generated/46d7a6324536.png filter=lfs diff=lfs merge=lfs -text
1914
+ illustrious_generated/593a5b377fb6.png filter=lfs diff=lfs merge=lfs -text
1915
+ illustrious_generated/9e588e17280b.png filter=lfs diff=lfs merge=lfs -text
1916
+ illustrious_generated/f086d40d76b7.png filter=lfs diff=lfs merge=lfs -text
1917
+ illustrious_generated/e5b4630753fb.png filter=lfs diff=lfs merge=lfs -text
1918
+ illustrious_generated/bf19c6825d9d.png filter=lfs diff=lfs merge=lfs -text
1919
+ illustrious_generated/ad90d731a0fa.png filter=lfs diff=lfs merge=lfs -text
1920
+ illustrious_generated/4c7a209b8887.png filter=lfs diff=lfs merge=lfs -text
1921
+ illustrious_generated/b52834e10fd4.png filter=lfs diff=lfs merge=lfs -text
1922
+ illustrious_generated/e0fa2a6e6b42.png filter=lfs diff=lfs merge=lfs -text
1923
+ illustrious_generated/b770edd1abe0.png filter=lfs diff=lfs merge=lfs -text
1924
+ illustrious_generated/bf8fcd7e3ebb.png filter=lfs diff=lfs merge=lfs -text
1925
+ illustrious_generated/b2a469095584.png filter=lfs diff=lfs merge=lfs -text
1926
+ illustrious_generated/4af582043c66.png filter=lfs diff=lfs merge=lfs -text
1927
+ illustrious_generated/666f1a84cb80.png filter=lfs diff=lfs merge=lfs -text
1928
+ illustrious_generated/c87663be02a7.png filter=lfs diff=lfs merge=lfs -text
1929
+ illustrious_generated/5fa4fb280d76.png filter=lfs diff=lfs merge=lfs -text
1930
+ illustrious_generated/c65bd94bf971.png filter=lfs diff=lfs merge=lfs -text
1931
+ illustrious_generated/a3f618970fb7.png filter=lfs diff=lfs merge=lfs -text
1932
+ illustrious_generated/c588cc386611.png filter=lfs diff=lfs merge=lfs -text
1933
+ illustrious_generated/e412e3f1af16.png filter=lfs diff=lfs merge=lfs -text
1934
+ illustrious_generated/acf002111806.png filter=lfs diff=lfs merge=lfs -text
1935
+ illustrious_generated/3dbcf74b53f1.png filter=lfs diff=lfs merge=lfs -text
1936
+ illustrious_generated/149a82673fe9.png filter=lfs diff=lfs merge=lfs -text
1937
+ illustrious_generated/4536122bec42.png filter=lfs diff=lfs merge=lfs -text
1938
+ illustrious_generated/25df61f584ba.png filter=lfs diff=lfs merge=lfs -text
1939
+ illustrious_generated/5803e88c45f3.png filter=lfs diff=lfs merge=lfs -text
1940
+ illustrious_generated/d5f24e4f2554.png filter=lfs diff=lfs merge=lfs -text
1941
+ illustrious_generated/57264915c81a.png filter=lfs diff=lfs merge=lfs -text
1942
+ illustrious_generated/b5a6eaf580ce.png filter=lfs diff=lfs merge=lfs -text
1943
+ illustrious_generated/18f1f90221cc.png filter=lfs diff=lfs merge=lfs -text
1944
+ illustrious_generated/798df3adbd88.png filter=lfs diff=lfs merge=lfs -text
1945
+ illustrious_generated/617ce74ffafa.png filter=lfs diff=lfs merge=lfs -text
1946
+ illustrious_generated/1dfed5dbd442.png filter=lfs diff=lfs merge=lfs -text
1947
+ illustrious_generated/243fd5bd7679.png filter=lfs diff=lfs merge=lfs -text
1948
+ illustrious_generated/2433395d6d43.png filter=lfs diff=lfs merge=lfs -text
1949
+ illustrious_generated/96ab8498c7a6.png filter=lfs diff=lfs merge=lfs -text
1950
+ illustrious_generated/4a574003bbc9.png filter=lfs diff=lfs merge=lfs -text
1951
+ illustrious_generated/d209e317052f.png filter=lfs diff=lfs merge=lfs -text
1952
+ illustrious_generated/fb9ed29c63d7.png filter=lfs diff=lfs merge=lfs -text
1953
+ illustrious_generated/ece3750ad6a6.png filter=lfs diff=lfs merge=lfs -text
1954
+ illustrious_generated/690e7719d4bf.png filter=lfs diff=lfs merge=lfs -text
1955
+ illustrious_generated/639cf9531d64.png filter=lfs diff=lfs merge=lfs -text
1956
+ illustrious_generated/56e66190e7ce.png filter=lfs diff=lfs merge=lfs -text
1957
+ illustrious_generated/ceb937e0801e.png filter=lfs diff=lfs merge=lfs -text
1958
+ illustrious_generated/ed0e6ab9592e.png filter=lfs diff=lfs merge=lfs -text
1959
+ illustrious_generated/312f15890343.png filter=lfs diff=lfs merge=lfs -text
1960
+ illustrious_generated/d21efb5f15d1.png filter=lfs diff=lfs merge=lfs -text
1961
+ illustrious_generated/b01a98afee9a.png filter=lfs diff=lfs merge=lfs -text
1962
+ illustrious_generated/3cbd0d9a36a7.png filter=lfs diff=lfs merge=lfs -text
1963
+ illustrious_generated/6994acf3db4f.png filter=lfs diff=lfs merge=lfs -text
1964
+ illustrious_generated/b5ee5bce0095.png filter=lfs diff=lfs merge=lfs -text
1965
+ illustrious_generated/e150c196c1de.png filter=lfs diff=lfs merge=lfs -text
1966
+ illustrious_generated/d0d5e2ddb402.png filter=lfs diff=lfs merge=lfs -text
1967
+ illustrious_generated/b102a265b15d.png filter=lfs diff=lfs merge=lfs -text
1968
+ illustrious_generated/b0202f93d233.png filter=lfs diff=lfs merge=lfs -text
1969
+ illustrious_generated/14cde88e8c02.png filter=lfs diff=lfs merge=lfs -text
1970
+ illustrious_generated/30f0d4a3b6b6.png filter=lfs diff=lfs merge=lfs -text
1971
+ illustrious_generated/33aeb710e683.png filter=lfs diff=lfs merge=lfs -text
1972
+ illustrious_generated/71485a12c097.png filter=lfs diff=lfs merge=lfs -text
1973
+ illustrious_generated/2c3ca96d94bc.png filter=lfs diff=lfs merge=lfs -text
1974
+ illustrious_generated/a630cf5674f3.png filter=lfs diff=lfs merge=lfs -text
1975
+ illustrious_generated/163359d65694.png filter=lfs diff=lfs merge=lfs -text
1976
+ illustrious_generated/2f054f26c97c.png filter=lfs diff=lfs merge=lfs -text
1977
+ illustrious_generated/af2a10d42166.png filter=lfs diff=lfs merge=lfs -text
1978
+ illustrious_generated/4d7d6abe842c.png filter=lfs diff=lfs merge=lfs -text
1979
+ illustrious_generated/62fb4f26c8d0.png filter=lfs diff=lfs merge=lfs -text
1980
+ illustrious_generated/8bc741c00933.png filter=lfs diff=lfs merge=lfs -text
1981
+ illustrious_generated/f3021039e6df.png filter=lfs diff=lfs merge=lfs -text
1982
+ illustrious_generated/7ed913ac6954.png filter=lfs diff=lfs merge=lfs -text
1983
+ illustrious_generated/a9ab451bd4c9.png filter=lfs diff=lfs merge=lfs -text
1984
+ illustrious_generated/49c90b238c77.png filter=lfs diff=lfs merge=lfs -text
1985
+ illustrious_generated/1371e80e6f16.png filter=lfs diff=lfs merge=lfs -text
1986
+ illustrious_generated/f284fb62c9bf.png filter=lfs diff=lfs merge=lfs -text
1987
+ illustrious_generated/71e9866c5494.png filter=lfs diff=lfs merge=lfs -text
1988
+ illustrious_generated/abf601407cb0.png filter=lfs diff=lfs merge=lfs -text
1989
+ illustrious_generated/83da1d18e129.png filter=lfs diff=lfs merge=lfs -text
1990
+ illustrious_generated/484afd596fd0.png filter=lfs diff=lfs merge=lfs -text
1991
+ illustrious_generated/4c0f1f03dcd2.png filter=lfs diff=lfs merge=lfs -text
1992
+ illustrious_generated/2008ee2e6f66.png filter=lfs diff=lfs merge=lfs -text
1993
+ illustrious_generated/48fcc54ce758.png filter=lfs diff=lfs merge=lfs -text
1994
+ illustrious_generated/4769b078890d.png filter=lfs diff=lfs merge=lfs -text
1995
+ illustrious_generated/b716f3a23f11.png filter=lfs diff=lfs merge=lfs -text
1996
+ illustrious_generated/8e740eadaf18.png filter=lfs diff=lfs merge=lfs -text
1997
+ illustrious_generated/88e259bbe761.png filter=lfs diff=lfs merge=lfs -text
1998
+ illustrious_generated/476617799f8d.png filter=lfs diff=lfs merge=lfs -text
1999
+ illustrious_generated/1bb92008067a.png filter=lfs diff=lfs merge=lfs -text
2000
+ illustrious_generated/38cdd5aae9bb.png filter=lfs diff=lfs merge=lfs -text
2001
+ illustrious_generated/5e8555d59ee2.png filter=lfs diff=lfs merge=lfs -text
2002
+ illustrious_generated/f3f41fd3689e.png filter=lfs diff=lfs merge=lfs -text
2003
+ illustrious_generated/ca3147d568a3.png filter=lfs diff=lfs merge=lfs -text
2004
+ illustrious_generated/8e375573fe30.png filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (28.2 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (192 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc ADDED
Binary file (41.4 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc ADDED
Binary file (5.43 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc ADDED
Binary file (2.05 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py ADDED
File without changes
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingA100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 166912
27
+ and str(metadata.device_capa) == "(8, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
44
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
45
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
46
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
50
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
51
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
53
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
57
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
58
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
59
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
61
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
62
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
63
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
64
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
65
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
66
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
67
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
68
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
69
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
70
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
71
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
72
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
73
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
74
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
75
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
76
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
77
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
78
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
79
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
80
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
81
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
82
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
83
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
84
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
85
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
86
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
87
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
88
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
89
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
90
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
91
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
92
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
93
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
94
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
95
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
96
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
97
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
98
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
99
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
100
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
101
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
102
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
103
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
104
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
105
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
106
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
107
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
108
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
109
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
110
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
111
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
112
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
113
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
114
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
115
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
116
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
117
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
118
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
119
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
120
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
121
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
122
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
123
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
124
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
125
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
126
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
127
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
128
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
129
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
130
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
131
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
132
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
133
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
134
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
135
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
136
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
137
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
138
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
139
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
140
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
141
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
142
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
143
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
144
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
145
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
146
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
147
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
148
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
149
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
150
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
151
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
152
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
153
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
154
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
155
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
156
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
157
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
158
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
159
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
160
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
161
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
162
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
163
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
164
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
165
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
166
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
167
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
168
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
169
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
170
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
171
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
172
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
173
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
174
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
175
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
176
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
177
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
178
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
179
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
180
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
181
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
182
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
183
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
184
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
185
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
186
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
187
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
188
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
189
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
190
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
191
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
192
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
193
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
194
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
195
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
196
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
197
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
198
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
199
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
200
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
201
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
202
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
203
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
204
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
205
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
206
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
207
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
208
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
209
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
210
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
211
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
212
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
213
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
214
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
215
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
216
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
217
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
218
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
219
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
220
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
221
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
222
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
223
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
224
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
225
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
226
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
227
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
228
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
229
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
230
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
231
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
232
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
233
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
234
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
235
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
236
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
237
+
238
+ def get_name(self) -> str:
239
+ return 'mm'
240
+
241
+ def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
242
+ if context.get_value('arith_intensity') <= 52.6245059967041:
243
+ if context.get_value('n') <= 34.0:
244
+ if context.get_value('n') <= 18.0:
245
+ if context.get_value('k*n') <= 312.0:
246
+ return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
247
+ else:
248
+ if context.get_value('k') <= 40.0:
249
+ return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
250
+ else:
251
+ return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
252
+ else:
253
+ if context.get_value('mat1_stride_0') <= 20.0:
254
+ return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
255
+ else:
256
+ if context.get_value('k') <= 68.0:
257
+ return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
258
+ else:
259
+ return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
260
+ else:
261
+ if context.get_value('k') <= 35.0:
262
+ if context.get_value('k') <= 18.0:
263
+ if context.get_value('m*n') <= 19505152.0:
264
+ return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
265
+ else:
266
+ return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
267
+ else:
268
+ if context.get_value('n') <= 68.0:
269
+ return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
270
+ else:
271
+ return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
272
+ else:
273
+ if context.get_value('m*n') <= 309760.0:
274
+ return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
275
+ else:
276
+ if context.get_value('n') <= 72.0:
277
+ return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
278
+ else:
279
+ return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
280
+ else:
281
+ if str(context.get_value('using_tf32')) != 'False':
282
+ if context.get_value('m*n') <= 815360.0:
283
+ if context.get_value('k') <= 1184.0:
284
+ return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
285
+ else:
286
+ return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
287
+ else:
288
+ if context.get_value('arith_intensity') <= 187.23922729492188:
289
+ if context.get_value('mat1_stride_0') <= 198.0:
290
+ return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
291
+ else:
292
+ return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
293
+ else:
294
+ return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
295
+ else:
296
+ return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingH100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 232448
27
+ and str(metadata.device_capa) == "(9, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
44
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
45
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
46
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
50
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
51
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
53
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
57
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
58
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
59
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
61
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
62
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
63
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
64
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
65
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
66
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
67
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
68
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
69
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
70
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
71
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
72
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
73
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
74
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
75
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
76
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
77
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
78
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
79
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
80
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
81
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
82
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
83
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
84
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
85
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
86
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
87
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
88
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
89
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
90
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
91
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
92
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
93
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
94
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
95
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
96
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
97
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
98
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
99
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
100
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
101
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
102
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
103
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
104
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
105
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
106
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
107
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
108
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
109
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
110
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
111
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
112
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
113
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
114
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
115
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
116
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
117
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
118
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
119
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
120
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
121
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
122
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
123
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
124
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
125
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
126
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
127
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
128
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
129
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
130
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
131
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
132
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
133
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
134
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
135
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
136
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
137
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
138
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
139
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
140
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
141
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
142
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
143
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
144
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
145
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
146
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
147
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
148
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
149
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
150
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
151
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
152
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
153
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
154
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
155
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
156
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
157
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
158
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
159
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
160
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
161
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
162
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
163
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
164
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
165
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
166
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
167
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
168
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
169
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
170
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
171
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
172
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
173
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
174
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
175
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
176
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
177
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
178
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
179
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
180
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
181
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
182
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
183
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
184
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
185
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
186
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
187
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
188
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
189
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
190
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
191
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
192
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
193
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
194
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
195
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
196
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
197
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
198
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
199
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
200
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
201
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
202
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
203
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
204
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
205
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
206
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
207
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
208
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
209
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
210
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
211
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
212
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
213
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
214
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
215
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
216
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
217
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
218
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
219
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
220
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
221
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
222
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
223
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
224
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
225
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
226
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
227
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
228
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
229
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
230
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
231
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
232
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
233
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
234
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
235
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
236
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
237
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
238
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
239
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
240
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
241
+
242
+ def get_name(self) -> str:
243
+ return 'mm'
244
+
245
+ def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
246
+ if context.get_value('arith_intensity') <= 29.89772129058838:
247
+ if context.get_value('n') <= 34.0:
248
+ if context.get_value('n') <= 18.0:
249
+ if context.get_value('k*n') <= 432.0:
250
+ if context.get_value('arith_intensity') <= 7.8700292110443115:
251
+ return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
252
+ else:
253
+ return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
254
+ else:
255
+ if context.get_value('k') <= 40.0:
256
+ return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
257
+ else:
258
+ return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
259
+ else:
260
+ if context.get_value('mat1_stride_0') <= 40.0:
261
+ if context.get_value('mat1_stride_0') <= 20.0:
262
+ return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
263
+ else:
264
+ return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
265
+ else:
266
+ if context.get_value('mat1_stride_0') <= 68.0:
267
+ return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
268
+ else:
269
+ return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
270
+ else:
271
+ if context.get_value('k') <= 18.0:
272
+ if context.get_value('m*k') <= 528.0:
273
+ return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
274
+ else:
275
+ if context.get_value('n') <= 80.0:
276
+ return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
277
+ else:
278
+ return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
279
+ else:
280
+ if context.get_value('k') <= 36.0:
281
+ if context.get_value('n') <= 68.0:
282
+ return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
283
+ else:
284
+ return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
285
+ else:
286
+ if context.get_value('mat2_stride_0') <= 384.0:
287
+ return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
288
+ else:
289
+ return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
290
+ else:
291
+ if context.get_value('arith_intensity') <= 56.995582580566406:
292
+ if context.get_value('n') <= 68.0:
293
+ if context.get_value('k*n') <= 4448.0:
294
+ if context.get_value('m*n') <= 29626368.0:
295
+ return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
296
+ else:
297
+ return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
298
+ else:
299
+ if context.get_value('k') <= 348.0:
300
+ return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
301
+ else:
302
+ return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
303
+ else:
304
+ if context.get_value('m') <= 3264.0:
305
+ return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
306
+ else:
307
+ if context.get_value('k') <= 62.5:
308
+ return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
309
+ else:
310
+ return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
311
+ else:
312
+ if context.get_value('m*n') <= 1097728.0:
313
+ return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
314
+ else:
315
+ if context.get_value('m*n') <= 3244032.0:
316
+ return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
317
+ else:
318
+ if context.get_value('n') <= 136.0:
319
+ return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
320
+ else:
321
+ return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MixedMMA100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 166912
27
+ and str(metadata.device_capa) == "(8, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_fallback_mixed_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
42
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
43
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
44
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
45
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
46
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
48
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
50
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
51
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
53
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
55
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
57
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
58
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
59
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
61
+
62
+ def get_name(self) -> str:
63
+ return 'mixed_mm'
64
+
65
+ def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
66
+ if str(context.get_value('1LEQmLEQ16')) != 'True':
67
+ if context.get_value('m') <= 32.5:
68
+ if context.get_value('n') <= 6976.0:
69
+ if context.get_value('n') <= 3520.0:
70
+ if context.get_value('m*n') <= 37632.0:
71
+ return None
72
+ else:
73
+ return [(1.000, 13)]
74
+ else:
75
+ if context.get_value('m*k') <= 452352.0:
76
+ return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
77
+ else:
78
+ return [(0.778, 8), (0.222, 13)]
79
+ else:
80
+ if context.get_value('k*n') <= 102776832.0:
81
+ if context.get_value('n') <= 14656.0:
82
+ return [(1.000, 11)]
83
+ else:
84
+ return [(0.889, 11), (0.111, 13)]
85
+ else:
86
+ return [(1.000, 11)]
87
+ else:
88
+ if context.get_value('m*n') <= 446464.0:
89
+ if context.get_value('m*n') <= 223424.0:
90
+ if context.get_value('mat1_stride_0') <= 3968.0:
91
+ return None
92
+ else:
93
+ return None
94
+ else:
95
+ if context.get_value('m*n') <= 346112.0:
96
+ return [(0.960, 16), (0.040, 7)]
97
+ else:
98
+ return [(0.750, 16), (0.136, 14), (0.114, 7)]
99
+ else:
100
+ if str(context.get_value('33LEQmLEQ64')) != 'True':
101
+ if context.get_value('n') <= 6976.0:
102
+ return [(1.000, 14)]
103
+ else:
104
+ return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
105
+ else:
106
+ if context.get_value('n') <= 13888.0:
107
+ return [(0.710, 14), (0.275, 21), (0.014, 12)]
108
+ else:
109
+ return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
110
+ else:
111
+ if context.get_value('n') <= 3520.0:
112
+ if context.get_value('arith_intensity') <= 3.994754433631897:
113
+ if str(context.get_value('mat2_dtype')) != 'torch.uint8':
114
+ if context.get_value('m*k') <= 18944.0:
115
+ return [(0.577, 5), (0.423, 6)]
116
+ else:
117
+ return [(0.988, 5), (0.012, 6)]
118
+ else:
119
+ if context.get_value('arith_intensity') <= 2.9899919033050537:
120
+ return None
121
+ else:
122
+ return None
123
+ else:
124
+ if context.get_value('arith_intensity') <= 7.956453561782837:
125
+ if context.get_value('k*n') <= 9244032.0:
126
+ return [(0.822, 5), (0.178, 6)]
127
+ else:
128
+ return [(0.977, 5), (0.023, 0)]
129
+ else:
130
+ if context.get_value('m*k') <= 978944.0:
131
+ return [(1.000, 5)]
132
+ else:
133
+ return [(0.971, 5), (0.029, 0)]
134
+ else:
135
+ if context.get_value('n') <= 13632.0:
136
+ if context.get_value('n') <= 6976.0:
137
+ return [(1.000, 6)]
138
+ else:
139
+ if context.get_value('k') <= 3968.0:
140
+ return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
141
+ else:
142
+ return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
143
+ else:
144
+ if context.get_value('k*n') <= 39518208.0:
145
+ return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
146
+ else:
147
+ if context.get_value('n') <= 20800.0:
148
+ return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
149
+ else:
150
+ return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MixedMMH100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 232448
27
+ and str(metadata.device_capa) == "(9, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_fallback_mixed_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
44
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
45
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
46
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
50
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
51
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
53
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
57
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
58
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
59
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
60
+
61
+ def get_name(self) -> str:
62
+ return 'mixed_mm'
63
+
64
+ def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
65
+ if context.get_value('arith_intensity') <= 15.988086223602295:
66
+ if context.get_value('n') <= 25280.0:
67
+ if context.get_value('n') <= 1344.0:
68
+ if context.get_value('mat1_stride_0') <= 7808.0:
69
+ return [(0.581, 7), (0.419, 6)]
70
+ else:
71
+ if context.get_value('m*n') <= 7680.0:
72
+ return [(0.875, 0), (0.125, 6)]
73
+ else:
74
+ return [(0.833, 0), (0.167, 7)]
75
+ else:
76
+ if context.get_value('n') <= 8512.0:
77
+ if str(context.get_value('mat2_dtype')) != 'torch.int8':
78
+ return [(0.763, 6), (0.237, 7)]
79
+ else:
80
+ return [(0.725, 7), (0.275, 6)]
81
+ else:
82
+ if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
83
+ return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
84
+ else:
85
+ return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
86
+ else:
87
+ if context.get_value('n') <= 42254.0:
88
+ if context.get_value('n') <= 33856.0:
89
+ if context.get_value('k*n') <= 68157440.0:
90
+ return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
91
+ else:
92
+ return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
93
+ else:
94
+ return [(0.659, 5), (0.341, 6)]
95
+ else:
96
+ if context.get_value('k*n') <= 326052992.0:
97
+ if context.get_value('n') <= 55232.0:
98
+ return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
99
+ else:
100
+ return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
101
+ else:
102
+ if context.get_value('n') <= 57024.0:
103
+ return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
104
+ else:
105
+ return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
106
+ else:
107
+ if context.get_value('m*n') <= 543936.0:
108
+ if str(context.get_value('17LEQmLEQ32')) != 'True':
109
+ if context.get_value('m*n') <= 262272.0:
110
+ if context.get_value('n') <= 1592.5:
111
+ return [(0.860, 0), (0.140, 9)]
112
+ else:
113
+ return None
114
+ else:
115
+ if context.get_value('m*k') <= 1294336.0:
116
+ return [(0.833, 17), (0.150, 18), (0.017, 15)]
117
+ else:
118
+ return [(0.917, 17), (0.083, 8)]
119
+ else:
120
+ if context.get_value('n') <= 12416.0:
121
+ if context.get_value('m*n') <= 43008.0:
122
+ return None
123
+ else:
124
+ return [(0.853, 14), (0.147, 9)]
125
+ else:
126
+ return [(0.625, 12), (0.375, 14)]
127
+ else:
128
+ if context.get_value('m') <= 32.5:
129
+ if context.get_value('mat2_stride_1') <= 6656.0:
130
+ if context.get_value('n') <= 69184.0:
131
+ return [(0.611, 12), (0.361, 14), (0.028, 13)]
132
+ else:
133
+ return [(1.000, 12)]
134
+ else:
135
+ if context.get_value('mat2_stride_1') <= 20864.0:
136
+ return [(1.000, 12)]
137
+ else:
138
+ return [(0.958, 12), (0.042, 9)]
139
+ else:
140
+ if context.get_value('m*n') <= 1085440.0:
141
+ if context.get_value('n') <= 9152.0:
142
+ return [(1.000, 18)]
143
+ else:
144
+ return [(0.780, 18), (0.160, 16), (0.060, 20)]
145
+ else:
146
+ if context.get_value('m') <= 67.0:
147
+ return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
148
+ else:
149
+ return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
5
+ from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
6
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
7
+ LearnedHeuristicRegression,
8
+ )
9
+
10
+
11
+ class PadMMA100(LearnedHeuristicRegression):
12
+
13
+ def __init__(self) -> None:
14
+ pass
15
+
16
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
17
+ return (
18
+ metadata.name == self.get_name()
19
+ and metadata.shared_memory == 166912
20
+ and str(metadata.device_capa) == "(8, 0)"
21
+ )
22
+
23
+ def get_feedback(self, context: AHContext, choice: Choice) -> float:
24
+ context.context_dict[CHOICE_COL] = choice
25
+ return self.predict(context)
26
+
27
+ def get_confidence_threshold(self) -> float:
28
+ return 1.7025303314066
29
+
30
+ def get_name(self) -> str:
31
+ return 'pad_mm'
32
+
33
+ def predict(self, context: AHContext) -> float:
34
+ if str(context.get_value('choice')) != 'pad':
35
+ if str(context.get_value('using_tf32')) != 'False':
36
+ if context.get_value('m*n') <= 4171264.0:
37
+ if context.get_value('m*k') <= 3999308.0:
38
+ return 1.8751469764071178
39
+ else:
40
+ if str(context.get_value('n_multiple_32')) != 'True':
41
+ return 0.9117231355626345
42
+ else:
43
+ return 1.1607689608873861
44
+ else:
45
+ if str(context.get_value('n_multiple_2')) != 'True':
46
+ if str(context.get_value('using_tf32')) != 'True':
47
+ return 0.7430382200435992
48
+ else:
49
+ return 0.8531269794448678
50
+ else:
51
+ if str(context.get_value('k_multiple_2')) != 'True':
52
+ return 0.7577181972719917
53
+ else:
54
+ return 0.8977349440424219
55
+ else:
56
+ if context.get_value('m*n') <= 1299712.0:
57
+ return 1.1669723418995592
58
+ else:
59
+ if context.get_value('mat2_stride_1') <= 45217.5:
60
+ if context.get_value('m*n') <= 55884158.0:
61
+ return 1.0262769936909601
62
+ else:
63
+ return 1.0022677428470845
64
+ else:
65
+ if context.get_value('m') <= 18478.0:
66
+ return 1.1127066261894312
67
+ else:
68
+ return 1.0337740659894263
69
+ else:
70
+ if str(context.get_value('mat1_dtype')) != 'torch.float32':
71
+ if str(context.get_value('n_multiple_2')) != 'False':
72
+ if str(context.get_value('k_multiple_2')) != 'True':
73
+ if context.get_value('mat1_stride_0') <= 561.0:
74
+ return 1.2900382135142956
75
+ else:
76
+ return 1.5761737616057887
77
+ else:
78
+ if context.get_value('num_dims_needs_padding') <= 1.5:
79
+ return 1.0472263310239422
80
+ else:
81
+ return 1.1727673465762514
82
+ else:
83
+ if context.get_value('k') <= 28238.5:
84
+ if context.get_value('k/(m*n)') <= 0.00026227018679492176:
85
+ return 1.6770542505397175
86
+ else:
87
+ return 1.3974785435105923
88
+ else:
89
+ if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
90
+ return 1.3952699800111992
91
+ else:
92
+ return 1.5759286511628336
93
+ else:
94
+ if str(context.get_value('using_tf32')) != 'False':
95
+ if context.get_value('m*n') <= 14119424.0:
96
+ return 0.8875772670422478
97
+ else:
98
+ if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
99
+ return 1.1467728924377265
100
+ else:
101
+ return 1.215842963532998
102
+ else:
103
+ if context.get_value('arith_intensity') <= 396.8774871826172:
104
+ return 0.89940161869551
105
+ else:
106
+ if context.get_value('mat2_stride_1') <= 45217.5:
107
+ return 0.9964328169353532
108
+ else:
109
+ return 0.9493479238294826
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py ADDED
File without changes
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from typing import Any, Callable, Optional
5
+
6
+ import torch
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ AHOperation,
11
+ Choice,
12
+ CHOICE_COL,
13
+ Feedback,
14
+ FEEDBACK_COL,
15
+ get_metadata_str_from_log,
16
+ )
17
+ from torch._inductor.autoheuristic.learned_heuristic_controller import (
18
+ LearnedHeuristicController,
19
+ )
20
+ from torch._inductor.ir import ChoiceCaller
21
+ from torch._inductor.runtime.runtime_utils import cache_dir
22
+ from torch._inductor.utils import get_gpu_shared_memory
23
+
24
+
25
+ class LocalFeedback:
26
+ """
27
+ To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
28
+ LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice
29
+ (see pad_mm.py, where the autotuning happens locally, for an example).
30
+ """
31
+
32
+ def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
33
+ self.feedback_fn = feedback_fn
34
+
35
+ def __call__(self, choice: Choice) -> Feedback:
36
+ return self.feedback_fn(choice)
37
+
38
+
39
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic logs to a file whose stored metadata does not
    match the metadata it would write for the current run.
    """
44
+
45
+
46
class AutoHeuristic:
    """
    Framework for collecting data and querying learned heuristics.

    AutoHeuristic records (context, choice, feedback) tuples; the collected
    data can be used offline to train a heuristic such as a regression tree
    (see torchgen/autoheuristic/), which can later be queried at runtime.
    """

    # Feedback gathered locally during this run, keyed by choice.
    collected_feedback: dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
        choices: list[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        Args:
            fallback: Returns a Choice when the heuristic is unsure which
                choice to make, or when AutoHeuristic is collecting data.
            choices: All choices the heuristic may pick from.
            feedback: Optional LocalFeedback that measures each choice.
            context: Features stored alongside each choice and feedback.
            name: Identifier of this heuristic.
            augment_context: Optional derived-feature operations applied to
                the context before querying a learned heuristic.
            precondition: Optional predicate gating whether AutoHeuristic runs.
        """
        self.fallback = fallback
        self.choices = choices
        self.feedback = feedback
        self.context = context
        self.name = name
        self.collected_feedback = {}
        self.augment_context = augment_context
        # Shared memory size + compute capability identify the GPU model the
        # data was collected on.
        self.metadata = AHMetadata(
            get_gpu_shared_memory(),
            torch.cuda.get_device_capability(),
            self.choices,
            self.name,
        )
        self.precondition = precondition

        if not self.satisfies_precondition():
            return

        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
            self.log_path = self.get_default_log_path()
        else:
            self.log_path = torch._inductor.config.autoheuristic_log_path

        # In collection mode, eagerly measure and log every choice.
        if torch._inductor.config.collect_autoheuristic(self.name):
            if self.feedback is not None:
                for candidate in self.choices:
                    self.save_data(candidate, self.feedback(candidate))

    def satisfies_precondition(self) -> bool:
        if self.precondition is None:
            return True
        return self.precondition(self.metadata, self.context)

    def get_choice(self) -> Choice:
        """
        Returns the chosen option based on the value of autoheuristic_use.
        If self.name is enabled there, a learned heuristic is queried;
        otherwise (or when the heuristic is unsure) the fallback is used.
        """
        if not self.satisfies_precondition():
            return self.fallback()

        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            decision = LearnedHeuristicController(
                self.metadata,
                self.context,
            ).get_decision()
            if decision not in self.choices:
                # TODO(AlnisM): We might want to allow this in the future
                return self.fallback()
            if decision is not None:
                return decision
        return self.fallback()

    def get_top_k_choices(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[Choice]]:
        """Ranked top-k choices from a learned heuristic, or None."""
        if not self.satisfies_precondition():
            return None
        if not torch._inductor.config.use_autoheuristic(self.name):
            return None
        if self.augment_context is not None:
            self.context.apply_operations(self.augment_context)
        controller = LearnedHeuristicController(
            self.metadata,
            self.context,
        )
        ranked = controller.get_decisions_ranked(top_k)
        if ranked is None:
            return None
        # Make sure mandatory choices survive the top-k cut.
        if always_included is not None:
            for mandatory in always_included:
                if mandatory not in ranked:
                    ranked.append(mandatory)
        return ranked

    def get_collected_feedback(self, choice: Choice) -> Any:
        return self.collected_feedback.get(choice, None)

    @staticmethod
    def get_device_identifier() -> str:
        # A heuristic might work well for one GPU but not another, so the
        # collected data (and learned heuristic) is keyed per GPU model.
        # TODO(AlnisM): just using the device name for now, but the same GPU
        # model can have different names
        return torch.cuda.get_device_name().replace(" ", "_")

    def get_default_log_path(self) -> str:
        directory = f"{cache_dir()}/autoheuristic/{self.get_device_identifier()}/"
        os.makedirs(directory, exist_ok=True)
        return directory + f"{self.name}.txt"

    def serialize_metadata(self) -> str:
        """Metadata plus feature schema as the JSON first-line of the log."""
        metadata_dict = self.metadata.to_dict()
        (
            numerical,
            categorical,
        ) = self.context.get_numerical_and_categorical_features()
        metadata_dict["numerical_features"] = numerical
        metadata_dict["categorical_features"] = categorical
        return json.dumps(metadata_dict)

    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
        """
        Appends one (features, choice, feedback) record to the log file,
        writing metadata and the CSV header first when the file is new.
        """
        self.collected_feedback[choice] = feedback_val
        log_path = self.log_path

        lines = []
        if os.path.exists(log_path):
            # An existing log must have been written with the same schema.
            metadata = self.serialize_metadata()
            existing_metadata = get_metadata_str_from_log(self.log_path)
            if existing_metadata != metadata:
                raise InconsistentMetadata(
                    "Given metadata does not match existing metadata"
                )
        else:
            lines.append(self.serialize_metadata())
            feature_header = self.context.get_feature_names_csv()
            lines.append(feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL)

        feature_values = self.context.get_feature_values_csv()
        lines.append(feature_values + "," + choice + "," + str(feedback_val))

        with open(log_path, "a") as f:
            f.write("\n".join(lines) + "\n")
215
+
216
+
217
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristic variant for kernel-choice selection: choices are
    ChoiceCallers and feedback (timings) comes from select_algorithm.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: list[ChoiceCaller],
        input_nodes: list[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        choices, input_nodes and name must match the arguments of the
        corresponding autotune_select_algorithm() call, e.g. for
        autotune_select_algorithm(name, choices, input_nodes, layout) the
        same name, choices and input_nodes have to be used here.
        """
        self.input_nodes = input_nodes
        # Map each ChoiceCaller's stable string id back to the caller.
        self.choicestr2choice: dict[str, ChoiceCaller] = {
            choice.autoheuristic_id(): choice for choice in choices
        }

        def fallback_str() -> str:
            fallback_choice = fallback()
            if fallback_choice is None:
                # TODO: Find a nicer way to handle this
                return "unsure"
            return fallback_choice.autoheuristic_id()

        super().__init__(
            fallback_str,
            list(self.choicestr2choice.keys()),
            None,
            context,
            name,
            augment_context,
            precondition,
        )

        if (
            torch._inductor.config.collect_autoheuristic(self.name)
            and self.satisfies_precondition()
        ):
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: list[Any], choices: list[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm that is called with the
        measured timing of every choice for this autotuning request.
        """

        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
            timings: dict[ChoiceCaller, float],
            name: str,
            input_nodes: list[Any],
            choices: list[ChoiceCaller],
        ) -> None:
            # Ignore feedback belonging to unrelated autotuning requests.
            current_inputs_key = create_inputs_key(input_nodes)
            if current_inputs_key != ah_inputs_key:
                return
            current_precompile_key = create_precompile_key(
                name, current_inputs_key, choices
            )
            if current_precompile_key != ah_precompile_key:
                return
            for choice, time in timings.items():
                self.save_data(choice.autoheuristic_id(), time)

        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        add_feedback_saver(partial(store_global_feedback, inputs_key, precompile_key))

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        return self.choicestr2choice.get(self.get_choice(), None)

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[ChoiceCaller]]:
        ranked = self.get_top_k_choices(top_k, always_included)
        if ranked is None:
            return None
        return [self.choicestr2choice[choice] for choice in ranked]
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Any, Callable
3
+
4
+ import torch
5
+
6
+
7
# Type aliases used throughout autoheuristic.
Feedback = float  # measured quality of a choice (e.g. a timing)
Choice = str  # identifier of one option the heuristic can pick
Value = Any  # value of a stored feature

# Column names used in collected-data logs.
CHOICE_COL = "choice"
FEEDBACK_COL = "feedback"
13
+
14
+
15
+ class AHFeature:
16
+ """
17
+ The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is
18
+ categorical (i.e., not a continuous variable) to learn a machine learning model.
19
+ """
20
+
21
+ def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
22
+ self.name = name
23
+ self.value = value
24
+ self.is_categorical = is_categorical
25
+
26
+
27
+ class AHOperation:
28
+ """
29
+ AHOperation can be used to augment the data collected by AutoHeuristic.
30
+ One might for example store features like m, k, n, but also want to use
31
+ features like m*n, or k*n, to learn a heuristic. Instead of storing features
32
+ that can be created from the collected data, one can use AHOperation to
33
+ create new features from the collected data.
34
+ """
35
+
36
+ def __init__(
37
+ self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
38
+ ) -> None:
39
+ self.name = name
40
+ self.func = func
41
+ self.is_categorical = is_categorical
42
+
43
+ def apply_operation(self, data: Any) -> None:
44
+ data[self.name] = self.func(data)
45
+
46
+
47
+ class AHContext:
48
+ """
49
+ This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will
50
+ store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
51
+ information that will help to learn a heuristic.
52
+ """
53
+
54
+ features: list[AHFeature]
55
+ context_dict: dict[str, Value]
56
+
57
+ def __init__(self) -> None:
58
+ self.features = []
59
+ self.context_dict = {}
60
+
61
+ def add_feature(
62
+ self, name: str, value: Value, is_categorical: bool = False
63
+ ) -> None:
64
+ self.features.append(AHFeature(name, value, is_categorical=is_categorical))
65
+ self.context_dict[name] = value
66
+
67
+ def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
68
+ numerical_features = []
69
+ categorical_features = []
70
+ for feature in self.features:
71
+ if feature.is_categorical:
72
+ categorical_features.append(feature.name)
73
+ else:
74
+ numerical_features.append(feature.name)
75
+
76
+ return numerical_features, categorical_features
77
+
78
+ def get_feature_names_csv(self) -> str:
79
+ return ",".join(feature.name for feature in self.features)
80
+
81
+ def get_feature_values_csv(self) -> str:
82
+ return ",".join(str(feature.value) for feature in self.features)
83
+
84
+ def get_value(self, name: str) -> Value:
85
+ return self.context_dict[name]
86
+
87
+ def apply_operations(self, operations: list[AHOperation]) -> None:
88
+ for op in operations:
89
+ op.apply_operation(self.context_dict)
90
+
91
+
92
+ class AHMetadata:
93
+ def __init__(
94
+ self,
95
+ shared_memory: Any,
96
+ device_capa: tuple[int, int],
97
+ choices: list[Choice],
98
+ name: str,
99
+ ) -> None:
100
+ # use amount of shared_memory and device_capability to identify GPU
101
+ # TODO(AlnisM): there might be a better way to do this
102
+ self.shared_memory = shared_memory
103
+ self.device_capa = device_capa
104
+ self.choices = choices
105
+ self.name = name
106
+
107
+ def to_dict(self) -> dict[str, Value]:
108
+ return {
109
+ "shared_memory": self.shared_memory,
110
+ "device_capa": self.device_capa,
111
+ "name": self.name,
112
+ }
113
+
114
+
115
def get_metadata_str_from_log(log_path: str) -> str:
    """Return the metadata JSON string stored on the first line of a log."""
    with open(log_path, newline="") as fh:
        return fh.readline().strip()
119
+
120
+
121
+ def check_minsize(context: AHContext, minsize: int) -> bool:
122
+ return (
123
+ context.get_value("m") >= minsize
124
+ and context.get_value("k") >= minsize
125
+ and context.get_value("n") >= minsize
126
+ )
127
+
128
+
129
+ def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
130
+ if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0):
131
+ # A100 precondition
132
+ return check_minsize(context, 512)
133
+ elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0):
134
+ # H100 precondition
135
+ return check_minsize(context, 768)
136
+ return True
137
+
138
+
139
+ def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
140
+ m = context.get_value("m")
141
+ k = context.get_value("k")
142
+ n = context.get_value("n")
143
+ if m > 128 or k < 1024 or n < 1024:
144
+ return False
145
+ mat1_iscontig = context.get_value("mat1_iscontig")
146
+ mat2_iscontig = context.get_value("mat2_iscontig")
147
+ return mat1_iscontig and not mat2_iscontig
148
+
149
+
150
+ def get_mult_dims_ops() -> list[AHOperation]:
151
+ m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
152
+ m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
153
+ k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
154
+ return [m_times_k_op, m_times_n_op, k_times_n_op]
155
+
156
+
157
def get_arith_intensity(data: Any) -> float:
    """
    Arithmetic-intensity proxy for an (m, k) x (k, n) matmul: m*k*n work
    over m*k + k*n + m*n elements touched. Returns 0.0 for degenerate dims.
    """
    m, k, n = data["m"], data["k"], data["n"]
    if 0 in (m, k, n):
        return 0.0
    return m * k * n / (m * k + k * n + m * n)
164
+
165
+
166
+ def pad_mm_operations() -> list[AHOperation]:
167
+ mult_dims_ops = get_mult_dims_ops()
168
+ k_div_m_times_n_op = AHOperation(
169
+ "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
170
+ )
171
+
172
+ def bfloat_perf_hit(data: Any) -> bool:
173
+ m = data["m"]
174
+ k = data["k"]
175
+ n = data["n"]
176
+ is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
177
+ return k > (m * 1024) and k > (n * 1024) and is_bfloat
178
+
179
+ bfloat_perf_hit_op = AHOperation(
180
+ "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
181
+ )
182
+
183
+ arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
184
+ dims_need_padding_ops = get_dims_need_padding_ops()
185
+ dims_multiple_ops = get_dims_multiple_ops()
186
+ is_contig_ops = get_is_contig_ops()
187
+
188
+ ah_operations = mult_dims_ops + [
189
+ k_div_m_times_n_op,
190
+ bfloat_perf_hit_op,
191
+ arith_intensity_op,
192
+ ]
193
+ ah_operations.extend(dims_need_padding_ops)
194
+ ah_operations.extend(dims_multiple_ops)
195
+ ah_operations.extend(is_contig_ops)
196
+ return ah_operations
197
+
198
+
199
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    """True iff data[dim] lies in the closed interval [lower, upper]."""
    return lower <= data[dim] <= upper
201
+
202
+
203
+ def between_ops() -> list[AHOperation]:
204
+ dims = ["m", "k", "n"]
205
+ limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
206
+ ah_operations = []
207
+ for dim in dims:
208
+ for lower, upper in limits:
209
+ between_op_fn = functools.partial(
210
+ between_op, dim=dim, lower=lower, upper=upper
211
+ )
212
+ # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
213
+ between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
214
+ ah_operations.append(
215
+ AHOperation(between_op_name, between_op_fn, is_categorical=True)
216
+ )
217
+ return ah_operations
218
+
219
+
220
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    """True iff data[dim] equals two to the given power."""
    return data[dim] == 2**exponent
222
+
223
+
224
+ def mm_operations() -> list[AHOperation]:
225
+ mult_dims_ops = get_mult_dims_ops()
226
+ arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
227
+ return mult_dims_ops + [arith_intensity_op]
228
+
229
+
230
+ def mixed_mm_operations() -> list[AHOperation]:
231
+ return mm_operations() + between_ops()
232
+
233
+
234
def is_multiple(data: Any, dim: str, mult: int) -> bool:
    """True iff data[dim] is divisible by *mult*."""
    return (data[dim] % mult) == 0
236
+
237
+
238
+ def get_dims_multiple_ops() -> list[AHOperation]:
239
+ multiples = [2, 4, 8, 16, 32]
240
+ dims = ["m", "k", "n"]
241
+ dims_multiple_ops = []
242
+ for dim in dims:
243
+ for mult in multiples:
244
+ is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
245
+ dims_multiple_op = AHOperation(
246
+ f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
247
+ )
248
+ dims_multiple_ops.append(dims_multiple_op)
249
+ return dims_multiple_ops
250
+
251
+
252
+ def get_dims_need_padding_ops() -> list[AHOperation]:
253
+ def mat1_innermost_needs_padding_fn(data: Any) -> bool:
254
+ mat1_stride_0 = data["mat1_stride_0"]
255
+ mat1_stride_1 = data["mat1_stride_1"]
256
+ m_padded_length = data["m_padded_length"]
257
+ k_padded_length = data["k_padded_length"]
258
+ mat1_innermost_needs_padding = False
259
+ if mat1_stride_0 == 1 and m_padded_length != 0:
260
+ mat1_innermost_needs_padding = True
261
+ if mat1_stride_1 == 1 and k_padded_length != 0:
262
+ mat1_innermost_needs_padding = True
263
+ return mat1_innermost_needs_padding
264
+
265
+ mat1_innermost_op = AHOperation(
266
+ "mat1_innermost_needs_padding",
267
+ mat1_innermost_needs_padding_fn,
268
+ is_categorical=True,
269
+ )
270
+
271
+ def mat2_innermost_needs_padding_fn(data: Any) -> bool:
272
+ mat2_stride_0 = data["mat2_stride_0"]
273
+ mat2_stride_1 = data["mat2_stride_1"]
274
+ k_padded_length = data["k_padded_length"]
275
+ n_padded_length = data["n_padded_length"]
276
+ mat2_innermost_needs_padding = False
277
+ if mat2_stride_0 == 1 and k_padded_length != 0:
278
+ mat2_innermost_needs_padding = True
279
+ if mat2_stride_1 == 1 and n_padded_length != 0:
280
+ mat2_innermost_needs_padding = True
281
+ return mat2_innermost_needs_padding
282
+
283
+ mat2_innermost_op = AHOperation(
284
+ "mat2_innermost_needs_padding",
285
+ mat2_innermost_needs_padding_fn,
286
+ is_categorical=True,
287
+ )
288
+
289
+ def num_dims_needs_padding_fn(data: Any) -> int:
290
+ m_padded_length = data["m_padded_length"]
291
+ k_padded_length = data["k_padded_length"]
292
+ n_padded_length = data["n_padded_length"]
293
+ num_dims_needs_padding = 0
294
+ if m_padded_length != 0:
295
+ num_dims_needs_padding += 1
296
+ if k_padded_length != 0:
297
+ num_dims_needs_padding += 1
298
+ if n_padded_length != 0:
299
+ num_dims_needs_padding += 1
300
+ return num_dims_needs_padding
301
+
302
+ num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
303
+ return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
304
+
305
+
306
+ def get_is_contig_ops() -> list[AHOperation]:
307
+ def mat1_is_contig_fn(data: Any) -> bool:
308
+ stride_0 = data["mat1_stride_0"]
309
+ stride_1 = data["mat1_stride_1"]
310
+ k = data["k"]
311
+ return stride_0 == k and stride_1 == 1
312
+
313
+ mat1_is_contig_op = AHOperation(
314
+ "mat1_iscontig", mat1_is_contig_fn, is_categorical=True
315
+ )
316
+
317
+ def mat2_is_contig_fn(data: Any) -> bool:
318
+ stride_0 = data["mat2_stride_0"]
319
+ stride_1 = data["mat2_stride_1"]
320
+ n = data["n"]
321
+ return stride_0 == n and stride_1 == 1
322
+
323
+ mat2_is_contig_op = AHOperation(
324
+ "mat2_iscontig", mat2_is_contig_fn, is_categorical=True
325
+ )
326
+
327
+ return [mat1_is_contig_op, mat2_is_contig_op]
328
+
329
+
330
+ def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None:
331
+ for i, s in enumerate(stride):
332
+ context.add_feature(f"{name}_stride_{i}", s)
333
+
334
+
335
+ def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
336
+ using_tf32 = "not_float_32"
337
+ if dtype == torch.float32:
338
+ using_tf32 = torch.backends.cuda.matmul.allow_tf32
339
+ context.add_feature("using_tf32", using_tf32, is_categorical=True)
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+ import pkgutil
4
+ from collections import defaultdict
5
+ from typing import Any, Optional
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
13
+
14
+
15
+ def find_and_instantiate_subclasses(
16
+ package_name: str, base_class: Any
17
+ ) -> list[LearnedHeuristic]:
18
+ instances = []
19
+
20
+ package = importlib.import_module(package_name)
21
+ for _, module_name, _ in pkgutil.walk_packages(
22
+ package.__path__, package.__name__ + "."
23
+ ):
24
+ try:
25
+ module_basename = module_name.split(".")[-1]
26
+ if not module_basename.startswith("_"):
27
+ # learned heuristics start with an underscore
28
+ continue
29
+ module = importlib.import_module(module_name)
30
+
31
+ # look for classes that are subclasses of base_class
32
+ for _name, obj in inspect.getmembers(module):
33
+ if (
34
+ inspect.isclass(obj)
35
+ and issubclass(obj, base_class)
36
+ and obj != base_class
37
+ ):
38
+ instance = obj()
39
+ instances.append(instance)
40
+ except Exception as e:
41
+ print(f"Error processing module {module_name}: {e}")
42
+
43
+ return instances
44
+
45
+
46
+ class LearnedHeuristicController:
47
+ """
48
+ Class that finds and instantiates all learned heuristics. It also provides
49
+ a way to get the decision of a learned heuristic.
50
+ """
51
+
52
+ existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
53
+ """
54
+ A dictionary that stores all the learned heuristics for each optimization.
55
+ The key is the optimization name, and the value is a list of LearnedHeuristic objects.
56
+ """
57
+
58
+ heuristics_initialized: bool = False
59
+ """
60
+ A flag that indicates whether the learned heuristics have been initialized.
61
+ Set to true when the get_decision() function is called for the first time.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ metadata: AHMetadata,
67
+ context: AHContext,
68
+ ) -> None:
69
+ self.metadata = metadata
70
+ self.context = context
71
+
72
+ def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
73
+ """
74
+ Returns a list of learned heuristics for the given optimization name.
75
+ """
76
+
77
+ if not LearnedHeuristicController.heuristics_initialized:
78
+ # learned heuristics are generated into the following package
79
+ learned_heuristics_package = "torch._inductor.autoheuristic.artifacts"
80
+
81
+ # learned heuristics have to be of type LearnedHeuristic
82
+ base_class = LearnedHeuristic
83
+ found_heuristics = find_and_instantiate_subclasses(
84
+ learned_heuristics_package, base_class
85
+ )
86
+
87
+ for learned_heuristic in found_heuristics:
88
+ opt_name = learned_heuristic.get_name()
89
+ LearnedHeuristicController.existing_heuristics[opt_name].append(
90
+ learned_heuristic
91
+ )
92
+ LearnedHeuristicController.heuristics_initialized = True
93
+
94
+ return LearnedHeuristicController.existing_heuristics[name]
95
+
96
+ def get_decision(self) -> Optional[Choice]:
97
+ """
98
+ Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
99
+ which choice to make.
100
+ """
101
+
102
+ heuristics = self.get_heuristics(self.metadata.name)
103
+ for heuristic in heuristics:
104
+ if heuristic.check_precondition(self.metadata, self.context):
105
+ return heuristic.get_decision(self.context, self.metadata.choices)
106
+ return None
107
+
108
+ def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
109
+ heuristics = self.get_heuristics(self.metadata.name)
110
+ for heuristic in heuristics:
111
+ if heuristic.check_precondition(self.metadata, self.context):
112
+ choices = heuristic.get_decisions_ranked(self.context)
113
+ if choices is None:
114
+ return None
115
+ avail_choices = [
116
+ choice for choice in choices if choice in self.metadata.choices
117
+ ]
118
+ return avail_choices[:top_k]
119
+ return None
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from typing import Optional
3
+
4
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
5
+ AHContext,
6
+ AHMetadata,
7
+ Choice,
8
+ )
9
+
10
+
11
+ class LearnedHeuristic:
12
+ """
13
+ LearnedHeuristic is a base class for all learned heuristics.
14
+ """
15
+
16
+ def __init__(self) -> None:
17
+ pass
18
+
19
+ def check_precondition(
20
+ self,
21
+ metadata: AHMetadata,
22
+ context: AHContext,
23
+ ) -> bool:
24
+ return True
25
+
26
+ def get_decision(
27
+ self, context: AHContext, choices: list[Choice]
28
+ ) -> Optional[Choice]:
29
+ return None
30
+
31
+ def get_confidence_threshold(self) -> float:
32
+ return 1.0
33
+
34
+ def get_name(self) -> str:
35
+ return ""
36
+
37
+ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
38
+ return None
39
+
40
+
41
+ class LearnedHeuristicRegression(LearnedHeuristic):
42
+ def __init__(self) -> None:
43
+ super().__init__()
44
+
45
+ def get_feedback(self, context: AHContext, choice: Choice) -> float:
46
+ return 1.0
47
+
48
+ def get_decision(
49
+ self, context: AHContext, choices: list[Choice]
50
+ ) -> Optional[Choice]:
51
+ choice2feedback = {}
52
+ for choice in choices:
53
+ predicted_feedback = self.get_feedback(context, choice)
54
+ choice2feedback[choice] = predicted_feedback
55
+ sorted_choices_feedback = sorted(
56
+ choice2feedback.items(), key=operator.itemgetter(1)
57
+ )
58
+ highest_feedback = sorted_choices_feedback[-1][1]
59
+ second_highest_feedback = sorted_choices_feedback[-2][1]
60
+ if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
61
+ return sorted_choices_feedback[-1][0]
62
+ # We are not sure which choice is the best one
63
+ return None
64
+
65
+
66
+ class LearnedHeuristicDecision(LearnedHeuristic):
67
+ def __init__(self) -> None:
68
+ super().__init__()
69
+
70
+ def get_choice(self, idx: int) -> Optional[str]:
71
+ return None
72
+
73
+ def get_decision(
74
+ self, context: AHContext, choices: list[Choice]
75
+ ) -> Optional[Choice]:
76
+ best_choices = self.get_best_choices(context)
77
+ if not best_choices:
78
+ return None
79
+ (best_choice_proba, best_choice_idx) = best_choices[0]
80
+ if best_choice_proba <= self.get_confidence_threshold():
81
+ return None
82
+ return self.get_choice(best_choice_idx)
83
+
84
+ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
85
+ feedback_idx_list = self.get_best_choices(context)
86
+ if feedback_idx_list is None:
87
+ return None
88
+ choices = [
89
+ self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
90
+ ]
91
+ choices = [choice for choice in choices if choice is not None]
92
+ return choices
93
+
94
+ def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
95
+ return []
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py ADDED
File without changes
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import torch
4
+ from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
5
+
6
+
7
+ # It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
8
+ # "...
9
+ # from ..codecache import CudaKernelParamCache
10
+ # ..."
11
+ # In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
12
+
13
+
14
def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
    """Rewrite CUDA identifiers in generated code to their HIP equivalents.

    Returns ``source_codes`` unchanged unless torch is a ROCm build
    (``torch.version.hip`` is set) or ``force_hipify`` is True.
    """
    if torch.version.hip is None and not force_hipify:
        return source_codes

    def _repl(match: re.Match[str]) -> object:
        return PYTORCH_MAP[match.group(0)]

    # We rebuild the pattern instead of reusing hipify_torch's, because that
    # one adds a positive lookbehind (?<=\W) to avoid matching a keyword at
    # the beginning of a code line — but generated code can legitimately start
    # a line with such a keyword, and the lookbehind would miss it.
    #
    # The lookahead (?=\W) is still needed to keep hipification idempotent:
    # e.g. "getStreamFromExternal" must not be replaced inside
    # "getStreamFromExternalMasqueradingAsCUDA".
    pattern = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
    return pattern.sub(_repl, source_codes)  # type: ignore[arg-type]
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Definition of AOTI runtime interface functions
2
+
3
+ #include <torch/csrc/inductor/aoti_runtime/interface.h>
4
+ #include <torch/csrc/inductor/aoti_runtime/model_container.h>
5
+
6
+ #include <iostream>
7
+ #include <vector>
8
+
9
// Wraps a statement block in try/catch: on success returns
// AOTI_RUNTIME_SUCCESS; any thrown exception is printed to stderr and
// converted into AOTI_RUNTIME_FAILURE, so no exception escapes the C ABI.
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...)                 \
  try {                                                      \
    __VA_ARGS__                                              \
  } catch (const std::exception& e) {                        \
    std::cerr << "Error: " << e.what() << '\n';              \
    return AOTI_RUNTIME_FAILURE;                             \
  } catch (...) {                                            \
    std::cerr << "Unknown exception occurred.\n";            \
    return AOTI_RUNTIME_FAILURE;                             \
  }                                                          \
  return AOTI_RUNTIME_SUCCESS;

// Asserts (via AOTI_RUNTIME_CHECK) that a caller-supplied array length
// matches what the model expects, with a descriptive error message.
#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name)  \
  do {                                                            \
    AOTI_RUNTIME_CHECK(                                           \
        actual_size == expected_size,                             \
        "expected " + std::string(name) + " vector size to be " + \
            std::to_string(expected_size) + ", but got " +        \
            std::to_string(actual_size));                         \
  } while (0)
29
+
30
// AOTInductor uses at::addmm_out, which doesn't support
// arguments that require gradient. For this reason, we
// enforce a no_grad context for run APIs.
//
// A RAII, thread local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct AOTINoGradGuard {
  AOTINoGradGuard() {
    aoti_torch_grad_mode_set_enabled(false);
  }
  // Non-copyable and non-movable: the guard is tied to the scope it was
  // constructed in.
  AOTINoGradGuard(const AOTINoGradGuard&) = delete;
  AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete;
  ~AOTINoGradGuard() {
    // Restore whatever grad mode was active before construction.
    aoti_torch_grad_mode_set_enabled(prev_mode);
  }
  AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete;
  AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete;
  // Captured at construction time, before grad mode is disabled.
  bool prev_mode{aoti_torch_grad_mode_is_enabled()};
};
49
+
50
+ extern "C" {
51
+
52
// Legacy entry point: maps the boolean is_cpu flag onto a device string and
// delegates to AOTInductorModelContainerCreateWithDevice.
AOTIRuntimeError AOTInductorModelContainerCreate(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    bool is_cpu,
    const char* cubin_dir) {
  return AOTInductorModelContainerCreateWithDevice(
      container_handle,
      num_models,
      is_cpu ? "cpu" : "cuda",
      cubin_dir);
}

// Creates a container holding num_models model instances on the given device.
// cubin_dir, when non-null, is where precompiled kernel binaries live.
// The caller owns the returned handle and must free it with
// AOTInductorModelContainerDelete.
AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir) {
  if (num_models == 0) {
    std::cerr << "Error: num_models must be positive, but got 0\n";
    return AOTI_RUNTIME_FAILURE;
  }
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::optional<std::string> cubin_dir_opt;
    if (cubin_dir != nullptr) {
      cubin_dir_opt.emplace(cubin_dir);
    }
    auto* container = new torch::aot_inductor::AOTInductorModelContainer(
        num_models, std::string(device_str), cubin_dir_opt);
    *container_handle =
        reinterpret_cast<AOTInductorModelContainerHandle>(container);
  })
}

// Destroys a container created by one of the Create functions above.
AOTIRuntimeError AOTInductorModelContainerDelete(
    AOTInductorModelContainerHandle container_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* container =
        reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
            container_handle);
    delete container;
  });
}
94
+
95
// Runs inference on one of the container's models. Input/output handle
// ownership is described inline; sizes are validated before dispatch, and
// grad mode is disabled for the duration of the call.
AOTIRuntimeError AOTInductorModelContainerRun(
    AOTInductorModelContainerHandle container_handle,
    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
                                     // are stolen; the array itself is borrowed
    size_t num_inputs,
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    size_t num_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
  AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");

  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run(
        input_handles, output_handles, stream, proxy_executor_handle);
  })
}

// Same contract as AOTInductorModelContainerRun, but dispatches to the
// container's single-threaded run path.
AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
    AOTInductorModelContainerHandle container_handle,
    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
                                     // are stolen; the array itself is borrowed
    size_t num_inputs,
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    size_t num_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
  AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");

  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run_single_threaded(
        input_handles, output_handles, stream, proxy_executor_handle);
  })
}
148
+
149
// Writes the number of constants (weights/buffers) held by the container.
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_constants) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *num_constants = container->num_constants(); })
}

// Writes the name of the idx-th constant. The string is owned by the
// container — the caller must not free it.
AOTIRuntimeError AOTInductorModelContainerGetConstantName(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** name) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *name = container->constant_name(idx); })
}

// Writes the original fully-qualified name of the idx-th constant as it
// appeared in the source module.
AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** original_fqn) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *original_fqn = container->constant_original_fqn(idx); })
}

// Writes whether the idx-th constant was produced by constant folding.
AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    bool* from_folded) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
}

// Writes the container-defined type code of the idx-th constant.
AOTIRuntimeError AOTInductorModelContainerGetConstantType(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    int32_t* type) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); })
}

// Writes the dtype code of the idx-th constant.
AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    int32_t* dtype) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *dtype = container->constant_dtype(idx); })
}

// Writes the size in bytes of the idx-th constant's data.
AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    size_t* data_size) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *data_size = container->constant_data_size(idx); })
}
220
+
221
// Copies the container's (active or inactive) name->tensor constant map into
// the caller-provided unordered_map behind constant_map_handle.
AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle,
    bool use_inactive) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto constants_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { const auto ret = container->extract_constants_map(use_inactive);
        for (const auto& pair: ret) {
          constants_map->emplace(pair.first, pair.second);
        }
      })
}

// Updates the constant buffer from a caller-owned map; user_managed = true
// means the container does not take ownership of the tensor storage.
AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle,
    bool use_inactive,
    bool validate_full_update) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->update_constant_buffer(
        *input_map, use_inactive, validate_full_update, /* user_managed = */ true);
  })
}

// Updates the (active or inactive) constant buffer from the supplied map.
// validate_full_update requires every constant to be present in the map.
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle,
    bool use_inactive,
    bool validate_full_update) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->update_constant_buffer(
        *input_map, use_inactive, validate_full_update);
  })
}

// Convenience wrapper: full, validated update of the inactive buffer
// (used for double-buffered weight swaps).
AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
          constant_map_handle,
          /*use_inactive*/ true,
          /*validate_full_update*/ true);
}

// Releases the storage held by the inactive constant buffer.
AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer(
    AOTInductorModelContainerHandle container_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->free_inactive_constant_buffer();
  })
}

// Re-runs constant folding over the selected buffer, with grad disabled.
AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
    AOTInductorModelContainerHandle container_handle,
    bool use_inactive,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run_const_fold(use_inactive, stream, proxy_executor_handle);
  })
}

// Atomically swaps the active and inactive constant buffers.
AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
    AOTInductorModelContainerHandle container_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->swap_constant_buffer();
  })
}
311
+
312
// Writes the number of model inputs.
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_inputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_inputs = container->num_inputs(); })
}

// Writes the name of the input at input_idx (string owned by the container).
AOTIRuntimeError AOTInductorModelContainerGetInputName(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const char** ret_input_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_input_names = container->input_name(input_idx); })
}

// Writes the number of model outputs.
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_outputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_outputs = container->num_outputs(); })
}

// Writes the name of the output at output_idx (string owned by the container).
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const char** ret_output_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_output_names = container->output_name(output_idx); })
}

// Writes the serialized pytree in/out specs describing the model's
// call signature (strings owned by the container).
AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
    AOTInductorModelContainerHandle container_handle,
    const char** in_spec,
    const char** out_spec) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    *in_spec = container->get_in_spec();
    *out_spec = container->get_out_spec();
  })
}
366
+
367
// Creates a single (container-less) model instance. When a constant map is
// supplied its entries are used as the model's constants; otherwise the
// constants baked into the compiled artifact are loaded.
AOTIRuntimeError AOTInductorModelCreate(
    AOTInductorModelHandle* model_handle,
    AOTInductorConstantMapHandle constant_map_handle){
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
    auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);

    auto model = new torch::aot_inductor::AOTInductorModel(
        constant_map,
        constant_array,
        "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
        ""
    );

    if (input_map) {
      for (auto const& kv : *input_map) {
        constant_map->emplace(kv.first, kv.second);
      }
    } else {
      model->load_constants();
    }

    *model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
  })}

// Runs the model on a null stream with no proxy executor (CPU path),
// with grad mode disabled for the duration of the call.
AOTIRuntimeError AOTInductorModelRun(
    AOTInductorModelHandle model_handle,
    AtenTensorHandle* input_handles,
    AtenTensorHandle* output_handles) {
  auto model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    model->run_impl(
        input_handles,
        output_handles,
        (torch::aot_inductor::DeviceStreamType) nullptr,
        nullptr);
  })
}

// Destroys a model created by AOTInductorModelCreate.
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
        model_handle);
    delete model;
  })}

// Writes the number of outputs produced by the model.
AOTIRuntimeError AOTInductorModelGetNumOutputs(
    AOTInductorModelHandle model_handle,
    size_t* ret_num_outputs) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
      auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
      *ret_num_outputs = model->num_outputs();
  })
}

// Replaces the model's constant map with a fresh one built from the
// caller-supplied name->tensor map.
AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
    AOTInductorModelHandle model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  auto model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto input_map =
        reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
            constant_map_handle);

    for (auto const& kv : *input_map) {
      constant_map->emplace(kv.first, kv.second);
    }
    model->update_constants_map(std::move(constant_map));
  })
}
442
+
443
+ } // extern "C"
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import functools
3
+ import textwrap
4
+ from typing import Optional
5
+
6
+ import sympy
7
+ from sympy import Expr, Symbol
8
+
9
+ from torch.utils._sympy.functions import FloorDiv, ModularIndexing
10
+
11
+ from ..utils import sympy_dot, sympy_subs
12
+ from ..virtualized import V
13
+
14
+
15
class BlockPatternMatcher:
    """
    Matches block indexing expressions.

    Used by the Triton codegen to recognize linear indices that decompose
    into multi-dimensional block indices (via FloorDiv/ModularIndexing)
    or into a simple affine stride * index form.
    """

    @classmethod
    def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
        """
        Given a sympy expression, return the subexpression comprised only of terms
        involving the specified symbol.

        For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
        this returns `x * 5 + x ** 2`.
        """
        expr = cls._preprocess(expr)
        # Adding S.Zero keeps the result a sympy Expr even when no term matches
        # (sum() of an empty generator would return the int 0).
        return sympy.S.Zero + sum(
            term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
        )

    @staticmethod
    def get_slice_numels(dims: list[Expr]) -> list[Expr]:
        """
        Compute the cumulative size of each dimension's slice.
        This proceeds from the last dim up to the second.
        """
        # numels[i] is the product of dims[i+1:]; the trailing entry is 1.
        numels = collections.deque([sympy.S.One])
        for dim in dims[:0:-1]:
            numel = dim * numels[0]
            numels.appendleft(numel)
        return [*numels]

    @staticmethod
    def _preprocess(expr: Expr) -> Expr:
        # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
        return expr.expand(identity=True)

    @classmethod
    def match_mod_div_block_expr(
        cls,
        index: Expr,
        index_var: Symbol,
        numel: Expr,
        num_dims: int,
    ) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]:
        """
        Matches modular indexing expressions, converting them to implied block dimensions and strides.
        See triton.py for more information.

        Returns (dims, strides, block_index_exprs) on a successful match,
        or None when `index` does not fit the mod/div block pattern.
        """
        index = cls._preprocess(index)

        # Pattern match to find the strides and offset.
        # Wildcards must not capture the index variable itself.
        wild = functools.partial(sympy.Wild, exclude=[index_var])
        dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
        strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]

        # The first dimension's index is computed by division.
        # The remaining are computed by modulo.
        slice_numels = cls.get_slice_numels(dims[:num_dims])
        block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
            ModularIndexing(index_var, numel, dim)
            for dim, numel in zip(dims[1:], slice_numels[1:])
        ]

        # Calculate a linear index from block indices.
        match_expr = sympy_dot(strides, block_index_exprs)

        # Heuristic: if the number of dimensions is high, check that the minimum requirements
        # are met before attempting an expensive full match. see triton.py:match_mod_div_block
        # for more details. In short, here we check that each subexpression in sympy.Add contains
        # only FloorDiv or ModularIndexing expressions.
        if num_dims >= 5:
            stride, denom, other = sympy.symbols("stride denominator other", cls=wild)
            mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
            floor_div_pattern = stride * FloorDiv(index_var, denom)
            first_dim_floor_div_matched = False
            match_failed = False
            for arg in sympy.Add.make_args(index):
                if arg.match(floor_div_pattern):
                    # There should only be a single FloorDiv(index, denom) expression
                    # corresponding to the first dimension
                    if first_dim_floor_div_matched:
                        match_failed = True
                        break
                    first_dim_floor_div_matched = True
                elif arg.match(mod_div_pattern):
                    continue
                else:
                    match_failed = True
                    break

            if match_failed:
                return None

        # Pattern match.
        match = index.match(match_expr)
        if match is None:
            return None

        # Provide default values for unmatched dims and strides.
        # A dim absent from the match degenerates to size 1; an absent
        # stride contributes nothing (0).
        for dim in dims[1:]:
            if dim not in match:
                match[dim] = sympy.S.One
        for stride in strides[1:]:
            if stride not in match:
                match[stride] = sympy.S.Zero

        sizevars = V.graph.sizevars

        def get_match(expr: Expr) -> Expr:
            return sizevars.lookup_precomputed_size(match[expr])

        # Replace wildcards with matched expressions.
        dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
        strides = [get_match(stride) for stride in strides]
        slice_numels = cls.get_slice_numels(dims)
        block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs]

        # The leading dimension is not directly matched in our expression.
        # We solve for it by dividing the range tree numel by the product of
        # all other dimensions. We quit if they are not known to be divisible.
        assert dims[0] not in match, "Expected not to match the leading dimension!"
        if not sizevars.statically_known_multiple_of(numel, slice_numels[0]):
            return None
        dims[0] = numel / slice_numels[0]

        # Sanity check that we can recover the index from the matched subexpressions.
        matched_index = sympy_dot(strides, block_index_exprs)
        assert sizevars.statically_known_equals(
            # New precomputed replacements may be generated when the `get_match` function
            # above is called, but the `index` that is being matched has not been updated.
            # So remove them when checking for equivalence e.g. if ps0=3*s0 and
            # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index
            sizevars.remove_precomputed_replacements(matched_index),
            sizevars.remove_precomputed_replacements(index),
        ), textwrap.dedent(
            f"""
            Invalid match!
            Index: {index}
            Matched expression: {matched_index}
            """
        )

        return dims, strides, block_index_exprs

    @classmethod
    def match_affine_block_expr(
        cls,
        index: Expr,
        index_var: Symbol,
    ) -> Optional[Expr]:
        """
        Matches simple expressions of the form stride * index, returning the
        stride.

        Returns None when `index` is not a pure multiple of `index_var`.
        """
        index = cls._preprocess(index)
        stride = sympy.Wild("stride", exclude=[index_var])
        m = index.match(index_var * stride)
        if m is None:
            return None

        return m[stride]
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py ADDED
@@ -0,0 +1,2691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import atexit
4
+ import contextlib
5
+ import dataclasses
6
+ import enum
7
+ import functools
8
+ import itertools
9
+ import logging
10
+ import math
11
+ import operator
12
+ import os
13
+ import re
14
+ import tempfile
15
+ from abc import ABC, abstractmethod
16
+ from enum import auto, Enum
17
+ from itertools import chain
18
+ from typing import (
19
+ Any,
20
+ Callable,
21
+ cast,
22
+ ClassVar,
23
+ Generic,
24
+ NamedTuple,
25
+ Optional,
26
+ TYPE_CHECKING,
27
+ Union,
28
+ )
29
+ from typing_extensions import Self, TypeVar
30
+
31
+ import sympy
32
+
33
+ import torch
34
+ import torch.fx
35
+ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
36
+ from torch.utils import _pytree as pytree
37
+ from torch.utils._ordered_set import OrderedSet
38
+ from torch.utils._sympy.numbers import int_oo
39
+ from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
40
+ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
41
+ from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
42
+
43
+ from .. import config, metrics
44
+ from ..dtype_propagation import DtypePropagationOpsHandler
45
+ from ..ops_handler import BasicMathOpsMixin, DefaultHandler
46
+ from ..utils import (
47
+ boolean_ops,
48
+ DeferredLineBase,
49
+ generate_assert,
50
+ get_current_backend,
51
+ IndentedBuffer,
52
+ ir_dataclass,
53
+ ScopedDict,
54
+ sympy_dot,
55
+ sympy_index_symbol,
56
+ sympy_subs,
57
+ triton_type,
58
+ unique,
59
+ )
60
+ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
61
+
62
+
63
+ if TYPE_CHECKING:
64
+ from collections.abc import Iterator, MutableMapping, Sequence
65
+
66
+ from torch.fx import GraphModule
67
+
68
+ from ..custom_graph_pass import CustomGraphModulePass
69
+ from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
70
+ from ..loop_body import LoopBody
71
+ from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
72
+ from .wrapper import PythonWrapperCodegen
73
+
74
+ _T = TypeVar("_T")
75
+ SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling]
76
+ WrapperConstructor = type[PythonWrapperCodegen]
77
+ SymbolLike = Union[str, sympy.Symbol]
78
+
79
+ # OpVarT should really be Union[CSEVariable, str], however this
80
+ # causes typing errors in subclasses (defined in other files).
81
+ OpVarT = str
82
+
83
+ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
84
+ log = logging.getLogger(__name__)
85
+
86
+
87
def data_type_logger(msg: str) -> None:
    """Emit a dtype-propagation message to the schedule artifact logger.

    Guarded on the DEBUG level so callers can log unconditionally without
    paying for message formatting when the artifact log is disabled.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)
90
+
91
+
92
@dataclasses.dataclass
class FileBackedGraphModule:
    """
    Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
    map back to a GraphModule instead of Python source.
    """

    # The FX module whose generated code this object exposes via `value`.
    gm: GraphModule
    # Callable invoked by `call`; presumably the compiled form of `gm` — produced elsewhere.
    compiled_fn: Callable[..., Any]

    def __post_init__(self) -> None:
        # Write the code to a file for compatibility with debugging utilities.
        # The file is deleted upon program termination.
        # delete=False + atexit: the file must outlive the `with` block below so
        # tooling can open it by path; removal is deferred to interpreter exit.
        self.tempfile = tempfile.NamedTemporaryFile(
            mode="w+", suffix=".py", delete=False
        )
        atexit.register(os.remove, self.tempfile.name)
        with self.tempfile as f:
            f.write(self.value)

    @property
    def __file__(self) -> str:
        # Mimic a real module's __file__ so debuggers can locate the source.
        return self.tempfile.name

    def call(self, args: list[Any]) -> Any:
        """Invoke the compiled function, unpacking `args` positionally."""
        return self.compiled_fn(*args)

    @property
    def value(self) -> str:
        """The Python source generated for the wrapped GraphModule."""
        return self.gm.code
122
+
123
+
124
class WorkspaceZeroMode(enum.Enum):
    """Zero-initialization requirement for a kernel workspace buffer."""

    UNINITIALIZED = 0
    ZERO_ON_CALL = 1  # kernel may leave workspace dirty
    ZERO_PER_GRAPH = 2  # must be re-zeroed by kernel

    @staticmethod
    def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode:
        """Merge two requirements; UNINITIALIZED defers to the other side."""
        for kept, other in ((a, b), (b, a)):
            if kept == other or other == WorkspaceZeroMode.UNINITIALIZED:
                return kept
        raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})")

    @staticmethod
    def from_bool(zero_fill: bool) -> WorkspaceZeroMode:
        """Map a boolean zero-fill flag onto the corresponding mode."""
        return (
            WorkspaceZeroMode.ZERO_ON_CALL
            if zero_fill
            else WorkspaceZeroMode.UNINITIALIZED
        )
142
+
143
+
144
class CodegenSymbol(ABC):
    """
    An IR object possibly corresponding to a variable in the wrapper code.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name this symbol is referred to by in generated code."""
        pass

    @abstractmethod
    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        """Return an example value (tensor or sympy symbol) for this symbol."""
        pass
156
+
157
+
158
@ir_dataclass(frozen=True)
class WorkspaceArg(CodegenSymbol):
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.

    Args:
        nbytes: The size of the buffer in bytes.
        zero_fill: Whether the buffer should be initialized to zero.

    """

    # Number of elements (of `dtype`) in the workspace.
    count: sympy.Expr
    # Zero-initialization requirement; see WorkspaceZeroMode.
    zero_mode: WorkspaceZeroMode
    device: torch.device
    # Name of the buffer in the wrapper code.
    outer_name: str
    # Name of the pointer argument inside the kernel.
    inner_name: str = "ws_ptr"
    dtype: torch.dtype = torch.uint8

    @staticmethod
    def unique_name(prefix: str = "workspace_") -> str:
        """Generate a fresh outer name using the graph-wide workspace counter."""
        return f"{prefix}{next(V.graph.workspace_id)}"

    @staticmethod
    def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool:
        """Two workspaces can be merged when kernel-side name/dtype/device agree."""
        return (
            a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device
        )

    @staticmethod
    def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Concatenate two compatible workspaces (counts add up)."""
        return WorkspaceArg(
            count=a.count + b.count,
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    @staticmethod
    def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Overlay two compatible workspaces (count is the symbolic max)."""
        assert (
            a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name
        )
        return WorkspaceArg(
            count=sympy.Max(a.count, b.count),
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code
    def get_device(self) -> torch.device:
        return self.device

    get_device_or_error = get_device

    def get_dtype(self) -> torch.dtype:
        return self.dtype

    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        return self.get_layout().get_example()

    def get_layout(self) -> FixedLayout:
        # Local import to avoid a circular dependency with ..ir.
        from ..ir import FixedLayout

        # A workspace is always a contiguous 1-D buffer of `count` elements.
        return FixedLayout(
            device=self.device,
            dtype=self.dtype,
            size=[self.count],
            stride=[1],
        )

    @property
    def layout(self) -> FixedLayout:
        return self.get_layout()

    # Buffer-protocol aliases so allocation code can treat this like an IR buffer.
    get_output_spec = get_layout
    maybe_get_output_spec = get_layout
    maybe_get_layout = get_layout

    def get_offset(self) -> sympy.Expr:
        return sympy.S.Zero

    def get_size(self) -> list[sympy.Expr]:
        return [self.count]

    def get_stride(self) -> list[sympy.Expr]:
        return [sympy.S.One]

    def get_name(self) -> str:
        return self.outer_name

    def get_inputs_that_alias_output(self) -> list[str]:
        # A workspace never aliases anything.
        return []
257
+
258
+
259
class TritonScratchWorkspace:
    """Size plus a lazily evaluated dtype string for a Triton scratch buffer."""

    def __init__(self, size: int, generate_dtype_str: Callable[..., str]):
        # The dtype string is produced on demand via the stored callback.
        self.size = size
        self._generate_dtype_str = generate_dtype_str

    def generate_dtype_str(self) -> str:
        """Materialize the dtype string by invoking the stored callback."""
        return self._generate_dtype_str()
266
+
267
+
268
@dataclasses.dataclass
class TensorArg:
    """A tensor kernel argument: maps a kernel-local name to a graph buffer."""

    # Name used inside the generated kernel.
    name: str
    # Name of the backing buffer in the graph.
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.S.Zero  # c++ only
    alias_of: Optional[str] = None  # halide only
275
+
276
+
277
@dataclasses.dataclass
class SizeArg:
    """A symbolic size kernel argument."""

    name: str
    # sympy expression giving the runtime value of this size.
    expr: sympy.Expr

    @property
    def alias_of(self) -> Optional[str]:
        # Always None; present so SizeArg can be handled uniformly with
        # TensorArg (which has an alias_of field).
        return None
285
+
286
+
287
@dataclasses.dataclass
class ConstexprArg:
    """A kernel parameter passed as a compile-time constant (name only)."""

    name: str
290
+
291
+
292
@dataclasses.dataclass
class TMADescriptorArg:
    """A TMA descriptor kernel argument (Triton host-side TMA APIs)."""

    name: str
    api_type: str  # "experimental" or "stable"
    block_shape: Optional[list[sympy.Expr]]  # only needed for "stable"
    dtype: Optional[torch.dtype]  # only needed for "stable"
298
+
299
+
300
@dataclasses.dataclass
class DeviceCodegen:
    """Bundle of codegen entry points registered for one device type."""

    # Factory producing the kernel-side scheduling for this device.
    scheduling: SchedulingConstructor
    # Python wrapper codegen class.
    wrapper_codegen: WrapperConstructor
    # Optional C++ wrapper codegen class (None when unsupported).
    cpp_wrapper_codegen: Optional[WrapperConstructor] = None
305
+
306
+
307
# Union of every argument kind a generated kernel can accept.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]

# Registry mapping a device type string (e.g. "cuda") to its codegen classes.
# Populated by register_backend_for_device / init_backend_registration.
device_codegens: dict[str, DeviceCodegen] = {}
310
+
311
+
312
class DeviceOpOverrides:
    """
    Per-device hooks that supply device-specific wrapper-code snippets
    (stream access, device/stream guards, kernel driver boilerplate, TMA
    helpers). Concrete backends register a subclass via
    register_device_op_overrides; every base method raises
    NotImplementedError so a missing override fails loudly.
    """

    def import_get_raw_stream_as(self, name: str) -> str:
        raise NotImplementedError

    def set_device(self, device_idx: int) -> str:
        raise NotImplementedError

    def synchronize(self) -> str:
        raise NotImplementedError

    def device_guard(self, device_idx: int) -> str:
        raise NotImplementedError

    def cpp_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_getStreamFromExternal(self) -> str:
        raise NotImplementedError

    def kernel_header(self) -> str:
        raise NotImplementedError

    def kernel_driver(self) -> str:
        raise NotImplementedError

    def cpp_stream_type(self) -> str:
        raise NotImplementedError

    def aoti_get_stream(self) -> str:
        raise NotImplementedError

    def cpp_kernel_type(self) -> str:
        raise NotImplementedError

    def cpp_device_ptr(self) -> str:
        raise NotImplementedError

    def tma_descriptor_helpers(self) -> str:
        raise NotImplementedError

    def cpp_global_scratch(
        self, idx: int, workspace: TritonScratchWorkspace
    ) -> Optional[tuple[list[str], str]]:
        # optionally return (scratch definition, arg name)
        raise NotImplementedError
366
+
367
+
368
# Registry of per-device runtime-op override helpers (see DeviceOpOverrides);
# populated by register_device_op_overrides.
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
# Optional custom FX graph pass per device; filled by register_backend_for_device.
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
370
+
371
+
372
+ # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
373
+ # For any new backend looking to integrate with Inductor, customization of these two main
374
+ # parts are necessary to generate its specific code.
375
+ #
376
+ # Kernel code generation is determined by different Scheduling. Consequently, a new
377
+ # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
378
+ # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
379
+ #
380
+ # For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code
381
+ # that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen,
382
+ # and override specific member functions to create backend-specific Python wrapper code.
383
+ #
384
+ # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
385
+ # of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces
386
+ # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
387
+ # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
388
+ # register_backend_for_device, to equip a new backend at runtime.
389
+ #
390
+ # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
391
+ # This backend can be used as a reference:
392
+ # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
393
def register_backend_for_device(
    device: str,
    device_scheduling: SchedulingConstructor,
    device_wrapper_codegen: WrapperConstructor,
    device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
    device_custom_pass: Optional[CustomGraphModulePass] = None,
) -> None:
    """
    Register codegen classes for a device type (see module comment above).

    Args:
        device: device type string, e.g. "cuda" or "cpu".
        device_scheduling: factory producing the kernel-side scheduling.
        device_wrapper_codegen: Python wrapper codegen class.
        device_cpp_wrapper_codegen: optional C++ wrapper codegen class.
        device_custom_pass: optional custom FX graph pass for this backend.
    """
    device_codegens[device] = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )
    custom_backend_passes[device] = device_custom_pass
404
+
405
+
406
class BackendFeature(Enum):
    """Optional codegen capabilities a backend may advertise; queried via
    get_backend_features / has_backend_feature."""

    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()
417
+
418
+
419
def get_backend_features(
    device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
    """Return the feature set of the backend registered for `device`.

    Accepts a torch.device, a device-type string, or None (empty set).
    Asserts that a scheduling is registered for the device type.
    """
    if device is None:
        return OrderedSet()
    # Make sure the built-in backends are registered before the lookup.
    init_backend_registration()
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        assert isinstance(device, str), type(device)
        device_type = device
        device = torch.device(device_type)
    scheduling_ctor = get_scheduling_for_device(device_type)
    assert scheduling_ctor
    # Instantiate a scheduling without a Scheduler just to query its features.
    scheduling = scheduling_ctor(None)
    return scheduling.get_backend_features(device)
435
+
436
+
437
def has_backend_feature(
    device: Union[torch.device, str, None], feature: BackendFeature
) -> bool:
    """Whether `device`'s registered backend advertises `feature`.

    See also V.graph.has_feature.
    """
    assert isinstance(feature, BackendFeature)
    supported = get_backend_features(device)
    return feature in supported
443
+
444
+
445
def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
    """Look up the scheduling constructor registered for `device`, if any."""
    entry = device_codegens.get(device)
    return entry.scheduling if entry is not None else None
447
+
448
+
449
def get_wrapper_codegen_for_device(
    device: str, cpp_wrapper: bool = False
) -> Optional[WrapperConstructor]:
    """Return the (C++ or Python) wrapper codegen class registered for
    `device`, or None when the device has no registered backend."""
    entry = device_codegens.get(device)
    if entry is None:
        return None
    return entry.cpp_wrapper_codegen if cpp_wrapper else entry.wrapper_codegen
460
+
461
+
462
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
    """Return the custom FX pass registered for `device`, or None."""
    # dict.get already yields None for unregistered devices.
    return custom_backend_passes.get(device)
464
+
465
+
466
@functools.cache
def init_backend_registration() -> None:
    """
    Register the built-in Inductor backends (cpu, cuda, xpu, mps) and, when
    available, the privateuse1 out-of-tree backend. Cached so the registration
    side effects run at most once per process.
    """
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
    from .cpp_wrapper_gpu import CppWrapperGpu
    from .cpp_wrapper_mps import CppWrapperMps
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .mps import MetalScheduling
    from .triton import TritonScheduling
    from .wrapper import PythonWrapperCodegen

    if get_scheduling_for_device("cpu") is None:
        # CPU scheduling is selected at call time via config.cpu_backend.
        cpu_backends = {
            "cpp": CppScheduling,
            "halide": HalideScheduling,
            "triton": TritonScheduling,
        }
        register_backend_for_device(
            "cpu",
            lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
            PythonWrapperCodegen,
            CppWrapperCpuArrayRef
            if config.aot_inductor.allow_stack_allocation
            else CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {
            "triton": CUDACombinedScheduling,
            "halide": HalideScheduling,
        }
        register_backend_for_device(
            "cuda",
            lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("xpu") is None:
        register_backend_for_device(
            "xpu",
            TritonScheduling,
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("mps") is None:
        register_backend_for_device(
            "mps",
            MetalScheduling,
            PythonWrapperCodegen,
            CppWrapperMps,
        )

    # "privateuseone" is the placeholder name; a renamed backend indicates a
    # real out-of-tree device whose codegen classes we try to fetch.
    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            # Best-effort: the custom module may not expose codegen hooks.
            pass
543
+
544
+
545
def index_prevent_reordering(
    index: Sequence[sympy.Expr],
    index_vars: Sequence[sympy.Expr],
    sizes: Sequence[sympy.Expr],
) -> list[sympy.Expr]:
    """Return `index` extended with a contiguous linear index over
    `index_vars`/`sizes` so that loop-reordering passes keep the original
    iteration order."""
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
554
+
555
+
556
def register_device_op_overrides(
    device: str, device_op_overrides: DeviceOpOverrides
) -> None:
    """Register the DeviceOpOverrides implementation for a device type."""
    device_op_overrides_dict[device] = device_op_overrides
560
+
561
+
562
def get_device_op_overrides(device: str) -> DeviceOpOverrides:
    """Return the DeviceOpOverrides registered for `device`.

    On first use, imports the built-in override modules (their import
    registers them as a side effect). Raises KeyError if `device` has no
    registered overrides.
    """
    assert isinstance(device, str), type(device)

    if not device_op_overrides_dict:
        # These imports populate device_op_overrides_dict via registration.
        from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict[device]
571
+
572
+
573
# Maps a tensor dtype to the dtype used for intermediate computation:
# reduced-precision floats (bfloat16/float16) are upcast to float32, while
# every other listed dtype computes in itself.
DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}
593
+
594
+
595
def deduce_output_dtype_by_name(
    op_name: str,
    *args: Any,
    **kwargs: Any,
) -> Optional[torch.dtype]:
    """
    Given op name and a list of input dtypes, deduce the output dtype

    Returns None when the dtype cannot be determined from the op name alone
    (callers then fall back to input-dtype promotion). The positional index
    used for each op mirrors that op's signature, where dtype may arrive
    either by keyword or at a fixed positional slot.
    """
    if op_name in boolean_ops():
        return torch.bool
    elif op_name in (
        "to_dtype",
        "index_expr",
    ):
        # dtype is the last positional argument when not passed by keyword.
        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
    elif op_name in (
        "rand",
        "randn",
    ):
        return torch.float
    elif op_name in (
        "get_index",
        "randint64",
        "load_seed",
    ):
        return torch.int64
    elif op_name == "reduction":
        # reduction(dtype, src_dtype, ...): output dtype is the 2nd positional.
        return kwargs["dtype"] if "dtype" in kwargs else args[1]
    elif op_name == "constant":
        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
    elif op_name in (
        "load",
        "store",
        "store_reduction",
    ):
        # args[0] is the ops handler; args[1] is the buffer name.
        buf_name = args[1]
        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
    elif op_name == "to_dtype_bitcast":
        # to_dtype_bitcast(x, dtype, src_dtype): dtype is second-to-last.
        return kwargs["dtype"] if "dtype" in kwargs else args[-2]
    return None
635
+
636
+
637
def check_dtype(
    buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype
) -> None:
    """Emit a backend-specific assertion into `buffer` that `var` has `dtype`.

    Only active under the corresponding test_configs flags: a Triton
    `tl.static_assert` for the triton backend, a C++ `static_assert` for cpp;
    otherwise a no-op.
    """
    backend = get_current_backend()
    if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
        buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})")
    elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
        from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP

        assert isinstance(var, CppCSEVariable), type(var)
        if dtype == torch.bool:
            if var.is_vec:
                is_same_dt = f"IsVecMaskType<decltype({var})>::value"
            else:
                # operator&(bool, bool) returns int and it can be used as boolean in C++
                is_same_dt = f"std::is_same_v<decltype({var}), bool> || std::is_same_v<decltype({var}), int>"
        else:
            c_var_type = f"decltype({var})"
            if var.is_vec:
                # Vectorized variables assert on their element type.
                c_var_type = f"typename {c_var_type}::value_type"
            is_same_dt = f"std::is_same_v<{c_var_type}, {DTYPE_TO_CPP[dtype]}>"

        buffer.writeline(f"static_assert({is_same_dt});")
660
+
661
+
662
class DataTypePropagation:
    """Propagates dtypes through a LoopBody's FX graphs, storing the result on
    each node's meta under OptimizationContext.key."""

    def __init__(self, body: LoopBody) -> None:
        self.body = body
        # "root" plus one graph per subblock (e.g. masked subblocks).
        self.graphs: dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Promote the dtypes of all non-placeholder inputs; None if any
        input has not been propagated yet."""
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        # Standard torch type promotion across all inputs.
        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype:
        """Use the subgraph's (output) dtype for nodes that call a subblock."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Deduce a single node's dtype: by op name first, then by inputs."""
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can infer output node if it only have 1 arg
            return None

        if node.target == operator.getitem:
            # getitem inherits the dtype of the tuple-producing node.
            node_arg = node.args[0]
            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
            return self.deduce_node_dtype(node_arg)

        assert isinstance(node.target, str), type(node.target)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]:
        """Propagate dtypes over `graph` in node order; returns the output
        node's dtype (if any)."""
        assert graph.nodes
        graph_dtype: Optional[torch.dtype] = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self) -> Optional[torch.dtype]:
        """Propagate starting from the root graph."""
        return self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
        """Convenience: propagate dtypes for a whole LoopBody."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
        """Convenience: propagate dtypes for a SchedulerNode's LoopBody."""
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode), type(node)
        assert isinstance(node._body, LoopBody), type(node._body)
        return DataTypePropagation.propagate_loopbody(node._body)
760
+
761
+
762
class PythonPrinter(_PythonPrinter):
    """Sympy-to-Python printer that simplifies via the graph's sizevars."""

    def doprint(
        self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True
    ) -> str:
        """Print `expr`; when possible, simplify it with V.graph.sizevars first."""
        # TODO: why are people passing strings to the printer here :think:
        # Guard with hasattr: V.graph may not have sizevars in all contexts.
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
770
+
771
+
772
class OpDecompositions:
    """
    Decomposes inductor ops

    Each staticmethod rewrites one op in terms of simpler ops via the
    `ops` handler, for backends that do not implement the op natively.
    """

    @staticmethod
    def identity(value: OpVarT) -> OpVarT:
        # used to trigger cse
        return value

    @staticmethod
    def reciprocal(x: OpVarT) -> OpVarT:
        # 1 / x
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x: OpVarT) -> OpVarT:
        return ops.mul(x, x)

    @staticmethod
    def erfc(x: OpVarT) -> OpVarT:
        # erfc(x) = 1 - erf(x)
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x: OpVarT) -> OpVarT:
        # erfcx(x) = exp(x^2) * erfc(x)
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x: OpVarT) -> OpVarT:
        # expm1(x) = exp(x) - 1
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x: OpVarT) -> OpVarT:
        # log10(x) = log(x) / log(10)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x: OpVarT) -> OpVarT:
        # log2(x) = log(x) / log(2)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x: OpVarT) -> OpVarT:
        # 2**x = exp(x * log(2))
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x: OpVarT) -> OpVarT:
        # log1p(x) = log(x + 1)
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x: OpVarT) -> OpVarT:
        # sigmoid(x) = 1 / (1 + exp(-x))
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def relu(x: OpVarT) -> OpVarT:
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT:
        # for backends that don't override this (halide)
        return ops.add(ops.mul(x, y), z)

    @staticmethod
    def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def remainder(a: OpVarT, b: OpVarT) -> OpVarT:
        # Python-style modulo: adjust mod's result so the sign follows `b`
        # when the remainder is nonzero and its sign differs from b's.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.round(a), dtype)
857
+
858
+ _RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE)
859
+
860
+
861
+ def _all_in_parens(string: str) -> bool:
862
+ if string[0] != "(" or len(string) < 2:
863
+ return False
864
+ count = 1
865
+ for i, char in enumerate(string[1:]):
866
+ if char == "(":
867
+ count += 1
868
+ elif char == ")":
869
+ count -= 1
870
+ if count == 0 and i != len(string) - 2:
871
+ return False
872
+ assert count == 0
873
+ return True
874
+
875
+
876
+ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
877
+ @staticmethod
878
+ def paren(string: OpVarT) -> OpVarT:
879
+ if (
880
+ isinstance(string, CSEVariable)
881
+ or _RE_PAREN_NOT_NEEDED.fullmatch(string)
882
+ or _all_in_parens(string)
883
+ ):
884
+ # don't put extra parens for strings that are already wrapped in parens
885
+ return string
886
+ return f"({string})"
887
+
888
+ @staticmethod
889
+ def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
890
+ return repr(value)
891
+
892
+ @staticmethod
893
+ def bitwise_not(x: OpVarT) -> OpVarT:
894
+ return f"~{OpOverrides.paren(x)}"
895
+
896
+ @staticmethod
897
+ def logical_not(a: OpVarT) -> OpVarT:
898
+ return f"{OpOverrides.paren(a)} == 0"
899
+
900
+ @staticmethod
901
+ def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT:
902
+ return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
903
+
904
+ @staticmethod
905
+ def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT:
906
+ return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
907
+
908
+ @staticmethod
909
+ def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT:
910
+ return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
911
+
912
+ @staticmethod
913
+ def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT:
914
+ return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
915
+
916
+ @staticmethod
917
+ def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT:
918
+ return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
919
+
920
+ @staticmethod
921
+ def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT:
922
+ # TODO: this is wrong
923
+ # TODO: an easy bandaid is to generate runtime asserts that it's
924
+ # <= 2**53, which is when this equation is correct
925
+ return ops.truediv(a, b)
926
+
927
+ @staticmethod
928
+ def load_seed(name: str, offset: OpVarT) -> OpVarT:
929
+ return ops.load(name, sympy.Integer(offset))
930
+
931
+ def indirect_indexing(
932
+ self,
933
+ var: OpVarT,
934
+ size: Union[sympy.Expr, int],
935
+ check: bool = True,
936
+ wrap_neg: bool = True,
937
+ ) -> sympy.Symbol:
938
+ return sympy_index_symbol(str(var))
939
+
940
+ def check_bounds(
941
+ self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
942
+ ) -> None:
943
+ raise NotImplementedError(
944
+ f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
945
+ )
946
+
947
+ def load(self, name: str, index: sympy.Expr) -> OpVarT:
948
+ raise NotImplementedError(
949
+ f"{type(self).__name__}: load should be handled by CSEProxy"
950
+ )
951
+
952
+ def store(
953
+ self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
954
+ ) -> None:
955
+ raise NotImplementedError(
956
+ f"{type(self).__name__}: store should be handled by CSEProxy"
957
+ )
958
+
959
+ def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
960
+ raise NotImplementedError(
961
+ f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
962
+ )
963
+
964
+ def reduction(
965
+ self,
966
+ dtype: torch.dtype,
967
+ src_dtype: torch.dtype,
968
+ reduction_type: ReductionType,
969
+ value: Union[OpVarT, tuple[OpVarT, ...]],
970
+ ) -> Union[OpVarT, tuple[OpVarT, ...]]:
971
+ raise NotImplementedError(
972
+ f"{type(self).__name__}: reduction should be handled by CSEProxy"
973
+ )
974
+
975
+ def scan(
976
+ self,
977
+ dtypes: tuple[torch.dtype, ...],
978
+ combine_fn: Callable[
979
+ [tuple[OpVarT, ...], tuple[OpVarT, ...]],
980
+ tuple[OpVarT, ...],
981
+ ],
982
+ values: tuple[OpVarT, ...],
983
+ ) -> tuple[OpVarT, ...]:
984
+ raise NotImplementedError(
985
+ f"{type(self).__name__}: scan should be handled by CSEProxy"
986
+ )
987
+
988
+ def sort(
989
+ self,
990
+ dtypes: tuple[torch.dtype, ...],
991
+ values: tuple[OpVarT, ...],
992
+ stable: bool,
993
+ descending: bool,
994
+ ) -> tuple[OpVarT, ...]:
995
+ raise NotImplementedError(
996
+ f"{type(self).__name__}: sort should be handled by CSEProxy"
997
+ )
998
+
999
+ def bucketize(
1000
+ self,
1001
+ values: OpVarT,
1002
+ boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
1003
+ boundary_indices: OpVarT,
1004
+ indexing_dtype: torch.dtype,
1005
+ right: bool,
1006
+ sorter: Optional[tuple[str, sympy.Expr]] = None,
1007
+ sorter_indices: Optional[OpVarT] = None,
1008
+ ) -> OpVarT:
1009
+ raise NotImplementedError(
1010
+ f"{type(self).__name__}: bucketize should be handled by CSEProxy"
1011
+ )
1012
+
1013
+ def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
1014
+ raise NotImplementedError(
1015
+ f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
1016
+ )
1017
+
1018
+ def inline_asm_elementwise(
1019
+ self,
1020
+ *inputs: OpVarT,
1021
+ asm: str,
1022
+ constraints: Optional[str] = None,
1023
+ dtype: torch.dtype = torch.float32,
1024
+ is_pure: bool = True,
1025
+ pack: int = 1,
1026
+ ) -> OpVarT:
1027
+ raise NotImplementedError(
1028
+ f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
1029
+ )
1030
+
1031
+ def output(self, *args: OpVarT) -> None:
1032
+ raise AssertionError(
1033
+ f"{type(self).__name__}: ops.output should not appear at codegen time"
1034
+ )
1035
+
1036
+ def placeholder(self, index: int) -> OpVarT:
1037
+ raise AssertionError(
1038
+ f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
1039
+ )
1040
+
1041
+ @staticmethod
1042
+ def _unimplemented(name: str) -> Callable[..., OpVarT]:
1043
+ def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
1044
+ raise NotImplementedError(
1045
+ f"{type(self).__name__} does not implement ops.{name}"
1046
+ )
1047
+
1048
+ unimplemented.__name__ = name
1049
+ unimplemented.is_unimplemented = True # type: ignore[attr-defined]
1050
+ return unimplemented
1051
+
1052
+ @classmethod
1053
+ def _is_unimplemented(cls, name: str) -> bool:
1054
+ fn = getattr(cls, name, None)
1055
+ default_fn = getattr(OpsHandler, name, None)
1056
+ return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)
1057
+
1058
+ @classmethod
1059
+ def _initialize_pointwise_overrides(cls, target: str) -> None:
1060
+ assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target
1061
+
1062
+ for funcname, data in pointwise_overrides_data.items():
1063
+ impl = getattr(data, target)
1064
+ if impl is None:
1065
+ if cls._is_unimplemented(funcname):
1066
+ setattr(cls, funcname, cls._unimplemented(funcname))
1067
+ else:
1068
+ assert funcname not in cls.__dict__, (
1069
+ f"multiple definitions of {funcname} on {cls.__name__}"
1070
+ )
1071
+ impl.__name__ = funcname
1072
+ setattr(cls, funcname, staticmethod(impl))
1073
+
1074
+
1075
@dataclasses.dataclass
class OverridesData:
    """Per-op table entry: how each codegen backend renders one pointwise
    special function (each field is a code-string formatter, or None when
    the backend has no implementation)."""

    # aten-level op name (e.g. "special_bessel_j0")
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    halide: Optional[Callable[..., str]] = None
    mps: Optional[Callable[..., str]] = None
1088
+
1089
+
1090
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
#
# Table of special-function pointwise ops keyed by the ops.* name.  Each entry
# supplies per-backend code formatters (see OverridesData); entries mentioned
# only in comments (entr, erfinv, logit, ...) are handled by decompositions
# instead of backend overrides.
pointwise_overrides_data: dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        # x is the (integer) order, y the argument; order 0/1 dispatch to the
        # cheaper digamma/trigamma kernels.
        cpp=lambda x,
        y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? trigamma({y}) : calc_polygamma({y}, {x}))",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)
1317
+
1318
+
1319
def is_buffer_removed(name: str) -> bool:
    """True if `name` was removed at the graph or current-kernel level
    (including inplace-alias removals)."""
    removal_sets = (
        V.graph.removed_buffers,
        V.kernel.removed_buffers,
        V.graph.inplaced_to_remove,
        V.kernel.inplaced_to_remove,
    )
    return any(name in removed for removed in removal_sets)
1329
+
1330
+
1331
class DeferredLine(DeferredLineBase):
    """A code line that is suppressed ("unwritten") whenever its buffer `name`
    later ends up in V.graph.removed_buffers (or kernel-level removal sets)."""

    def __init__(self, name: str, line: str):
        super().__init__(line)
        self.name = name
        # nesting DeferredLines would make suppression ambiguous
        assert not isinstance(line, DeferredLineBase)

    def __call__(self) -> Optional[str]:
        # Emit nothing once the buffer has been removed.
        if is_buffer_removed(self.name):
            return None
        return self.line

    def _new_line(self, line: str) -> DeferredLine:
        return DeferredLine(self.name, line)
1346
+
1347
+
1348
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indent() also emits C-style braces.

    A positive offset opens `{` blocks on entry and closes them on exit; a
    negative offset does the reverse (closes `}` on entry, reopens on exit).
    """

    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
        @contextlib.contextmanager
        def ctx() -> Iterator[None]:
            if offset >= 0:
                for _ in range(offset):
                    self.writeline("{")
                    self._indent += 1
                yield
                for _ in range(offset):
                    self._indent -= 1
                    self.writeline("}")
            else:
                # Negative offset: temporarily dedent out of braces.
                for _ in range(-offset):
                    self._indent -= 1
                    self.writeline("}")
                yield
                for _ in range(-offset):
                    self.writeline("{")
                    self._indent += 1

        return ctx()
1367
+
1368
+
1369
class InplacedBuffer(NamedTuple):
    """An in_out_ptr kernel argument shared by several aliased buffers."""

    # kernel-signature name, e.g. "in_out_ptr0"
    inner_name: str
    # all outer (graph-level) buffer names aliased onto this argument
    other_names: list[str]
1372
+
1373
+
1374
@dataclasses.dataclass
class ArgName:
    """A kernel-signature argument name, optionally marked tl.constexpr."""

    name: str
    # is_constexpr=True appends " : tl.constexpr" to the rendered argument
    is_constexpr: bool = False

    def full_name(self) -> str:
        """Render the argument as it appears in a Triton kernel signature."""
        suffix = " : tl.constexpr" if self.is_constexpr else ""
        return f"{self.name}{suffix}"
1382
+
1383
+
1384
class RemovedArg:
    """Sentinel marking an argument slot whose buffer was removed."""

    def __str__(self) -> str:
        return "REMOVED"


# Shared singleton; membership checks use isinstance(..., RemovedArg).
REMOVED = RemovedArg()
1390
+
1391
+
1392
class KernelArgs:
    """Registry of a kernel's arguments.

    Maps outer (graph-level) buffer names and size expressions to inner
    (kernel-signature) names such as ``in_ptr0``, ``out_ptr0``,
    ``in_out_ptr0`` and ``ks0``, and renders argument lists for both the
    C++ and Python codegen paths.
    """

    @staticmethod
    def _lookup(
        prefix: str,
        odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]],
        name: _T,
    ) -> str:
        """Return the inner name for `name`, minting "{prefix}{N}" on first use
        (N = current dict size, acting as a stable counter)."""
        result: Union[str, RemovedArg] = odict.get(name, REMOVED)
        if isinstance(result, RemovedArg):
            odict[name] = new_result = f"{prefix}{len(odict)}"
            return new_result
        return result

    def __init__(self) -> None:
        # outer name -> inner name ("in_ptr0", "seed0", ...)
        self.input_buffers: dict[str, str] = {}
        # outer name -> inner name ("out_ptr0", ...) or RemovedArg sentinel
        self.output_buffers: dict[str, Union[str, RemovedArg]] = {}
        # outer name -> shared InplacedBuffer ("in_out_ptr0") or RemovedArg
        self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {}
        # size expression -> inner name ("ks0", ...)
        self.sizevars: dict[sympy.Expr, str] = {}
        # scratch/semaphore buffers required by the kernel
        self.workspace_args: list[WorkspaceArg] = []

    def __repr__(self) -> str:
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    @staticmethod
    def _buffer_is_marked_removed(name: Any) -> bool:
        # this function is needed by MTIA
        return isinstance(name, RemovedArg)

    def input(self, name: str) -> str:
        """Register (or look up) `name` as a kernel input; return its inner name."""
        if V.graph.scheduler:
            # resolve mutations to the buffer actually holding the data
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        # A buffer already registered as an output/inplace arg reuses that name.
        if name in self.output_buffers:
            return cast(str, self.output_buffers[name])
        if name in self.inplace_buffers:
            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name: str) -> str:
        """Register (or look up) `name` as a kernel output; return its inner name."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name: str, output_name: str) -> None:
        """Alias `output_name` onto `input_name` as one in_out_ptr argument."""
        if input_name in V.graph.unaligned_buffers:
            # alignment is a property of the underlying storage, so it carries over
            V.graph.unaligned_buffers.add(output_name)
        assert output_name not in self.inplace_buffers, output_name
        if input_name in self.inplace_buffers:
            # Extend the existing alias group.
            buf = self.inplace_buffers[input_name]
            assert not isinstance(buf, RemovedArg)
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            # Index includes removed slots so names stay unique over time.
            alive_buffers = [
                val
                for val in self.inplace_buffers.values()
                if not isinstance(val, RemovedArg)
            ]
            removed_buffers = [
                val
                for val in self.inplace_buffers.values()
                if isinstance(val, RemovedArg)
            ]
            inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers)
            buf = InplacedBuffer(
                f"in_out_ptr{inplace_buffer_idx}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
        """
        Allocate or extend a workspace buffer of nbytes bytes.

        This function manages the allocation of a workspace buffer. It either creates
        a new WorkspaceArg or extends an existing one.

        Note:
        - Calling this function will in-place mutate the args by adding or updating
          a WorkspaceArg.
        - The codegen for generating the Python argdefs and call_defs will check
          this field and allocate the buffer accordingly.
        - A new argument "ws_ptr" will be present in the generated code.

        Args:
            nbytes (sympy.Expr): The number of bytes to allocate.
            zero_fill (bool): Whether to initialize the buffer to zero.

        Returns:
            Tuple[str, int]: A tuple containing:
            - "ws_ptr": A string identifier for the workspace pointer.
            - offset: An integer representing the byte offset in the workspace.
        """
        arg = WorkspaceArg(
            count=nbytes,
            zero_mode=WorkspaceZeroMode.from_bool(zero_fill),
            device=V.graph.get_current_device_or_throw(),
            outer_name=WorkspaceArg.unique_name(),
        )
        for i, existing_arg in enumerate(self.workspace_args):
            if WorkspaceArg.can_join(existing_arg, arg):
                # Extend the joinable workspace; new data starts at the old end.
                offset = existing_arg.count
                self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg)
                return existing_arg.inner_name, offset
            assert (
                existing_arg.inner_name != arg.inner_name
                and existing_arg.outer_name != arg.outer_name
            ), existing_arg
        self.workspace_args.append(arg)
        return arg.inner_name, 0

    def semaphores(self, min_size: sympy.Expr) -> str:
        """
        Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by
        all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit.

        Warning: multiple calls to this function will return the same buffer.

        Args:
            min_size: the number of int32 semaphores required

        Returns:
            name of the semaphores buffer
        """
        current_device = V.graph.get_current_device_or_throw()
        arg = WorkspaceArg(
            count=min_size,
            zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH,
            dtype=torch.uint32,
            inner_name="sem_ptr",
            outer_name=f"semaphores_{current_device.type}_{current_device.index}",
            device=current_device,
        )
        for existing_arg in self.workspace_args:
            if existing_arg.inner_name == arg.inner_name:
                # already allocated for this graph; must be identical
                assert arg == existing_arg, (arg, existing_arg)
        self.workspace_args.append(arg)
        return arg.inner_name

    def seed_offset(self, name: str, value: int) -> str:
        """Lift a constant RNG seed offset into a sizevar-style argument."""
        assert isinstance(value, int), (type(value), value)
        # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits
        value = sympy.Integer(value)
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            # disambiguate by appending the count of existing same-prefix names
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name: sympy.Symbol) -> str:
        """Register (or look up) a size symbol; return its inner "ksN" name."""
        assert isinstance(name, sympy.Symbol), (type(name), name)
        if name.name == "seed":
            self.sizevars[name] = "seed"  # don't manage the name of seeds
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self) -> Iterator[str]:
        """Outer names/exprs passed at the call site (inplace args excluded)."""
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def arg_name(self, name: str) -> Optional[str]:
        """
        Returns inner name of a given outer name.
        """
        inplaced = self.inplace_buffers.get(name, None)
        if inplaced is not None and not isinstance(inplaced, RemovedArg):
            return inplaced.inner_name
        output_name = self.output_buffers.get(name, None)
        if output_name is not None and not isinstance(output_name, RemovedArg):
            return output_name
        return self.input_buffers.get(name, None)

    def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
        """Hook for subclasses to wrap a pointer argument at the call site."""
        return buf

    def wrap_size_arg(self, size: SymbolLike) -> str:
        """Hook for subclasses to wrap a size argument at the call site."""
        return str(size)

    def cpp_argdefs(
        self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None
    ) -> tuple[list[str], list[str], list[str]]:
        """Render (arg_defs, call_args, arg_types) for a C++ kernel signature."""
        from .cpp_utils import INDEX_TYPE

        if dtype_to_cpp_type is None:
            from .cpp_utils import DTYPE_TO_CPP

            dtype_to_cpp_type = DTYPE_TO_CPP

        call_args = []
        arg_defs = []
        arg_types = []
        # in_out_ptr arguments first, then pure inputs, outputs, and sizevars
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, maybe_inner in self.output_buffers.items():
            if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            arg_defs.append(f"{cpp_dtype}* {maybe_inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert not self.workspace_args, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(
        self,
    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
        """Render (arg_defs, call_args, precompile_args, arg_types) for the
        Python/Triton launch path; same ordering as cpp_argdefs plus workspaces."""
        arg_defs: list[ArgName] = []
        call_args: list[str] = []
        arg_types: list[Any] = []
        precompile_args: list[KernelArgType] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            arg_defs.append(ArgName(inplaced.inner_name))
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
                continue
            arg_defs.append(ArgName(inner))
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(ArgName(inner))
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        for arg in self.workspace_args:
            arg_defs.append(ArgName(arg.inner_name))
            call_args.append(arg.outer_name)
            precompile_args.append(arg)
            arg_types.append(arg.dtype)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self) -> Iterator[tuple[str, str]]:
        """Yield (stale inner name, in_out_ptr inner name) pairs for buffers
        that were first registered separately and later made inplace."""
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield cast(str, self.output_buffers[other]), inplaced.inner_name

    def is_removed(self, name: str) -> bool:
        """True when `name` is marked removed as both an output and inplace arg."""
        return isinstance(
            self.output_buffers.get(name, REMOVED), RemovedArg
        ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg)

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self) -> OrderedSet[str]:
        live_outs: OrderedSet[str] = OrderedSet()
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
                continue
            live_outs.add(outer)
        return live_outs
1723
+
1724
+
1725
class CSEVariable:
    """A name bound to a common-subexpression, plus backend-annotatable metadata.

    Backends may subclass (overloading ``Kernel.create_cse_var``) and hook
    ``update_on_args`` to attach annotations; see TritonCSEVariable in
    triton.py.  Identity (hash/eq) is purely by name.
    """

    def __init__(
        self,
        name: str,
        bounds: ValueRanges[Any],
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        assert isinstance(bounds, ValueRanges), type(bounds)
        self.name = name
        self.bounds = bounds
        # how many expressions reuse this variable (starts at 1 for its definition)
        self.use_count = 1
        self.dtype = dtype

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name!r})"

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CSEVariable):
            return False
        return other.name == self.name

    def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
        # Annotation hook for subclasses; no-op by default.
        pass
1759
+
1760
+
1761
# Cache-key type produced by CSE.augment_key() (plain str unless a backend
# mixes in extra state).
AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
# Backend-specific CSEVariable subclass managed by a CSE instance.
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
1763
+
1764
if TYPE_CHECKING:
    # Key for CSE.reduction_cache: (output dtype, reduction kind, input value(s)).
    ReductionCacheKey = tuple[
        torch.dtype,
        ReductionType,
        Union[CSEVariable, tuple[CSEVariable, ...]],
    ]
1770
+
1771
+
1772
+ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
1773
+ """Common subexpression elimination"""
1774
+
1775
    def __init__(
        self,
        prefix: str = "",
        suffix: str = "",
        name_prefix: str = "tmp",
        iter_buffers: Optional[itertools.count[int]] = None,
        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
        reduction_cache: Optional[
            MutableMapping[ReductionCacheKey, CSEVariableType]
        ] = None,
        varname_map: Optional[dict[str, CSEVariableType]] = None,
    ):
        """Create a CSE namespace.

        Optional caches/counters may be shared with a parent CSE (see clone()),
        so clones hand out non-colliding names and see the same cached values.
        """
        # text emitted before/after each generated assignment line
        self.prefix = prefix
        self.suffix = suffix
        # expression text (possibly augmented, see augment_key) -> variable
        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
        self.name_prefix = name_prefix
        # buffer name -> variable last stored to it
        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
            reduction_cache or {}
        )
        # shared counter so cloned CSEs never reuse a variable id
        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
        self.invalidated_stores: OrderedSet[str] = OrderedSet()
        # variable name -> CSEVariable object
        self.varname_map: dict[str, CSEVariableType] = varname_map or {}
1798
+
1799
+ def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
1800
+ for name, tmp in [*self.store_cache.items()]:
1801
+ if tmp not in keep_vars:
1802
+ del self.store_cache[name]
1803
+ self.invalidated_stores.add(name)
1804
+ if keep_vars:
1805
+ self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
1806
+ else:
1807
+ self._cache = {}
1808
+
1809
+ def clone(self) -> Self:
1810
+ return type(self)(
1811
+ prefix=self.prefix,
1812
+ suffix=self.suffix,
1813
+ name_prefix=self.name_prefix,
1814
+ iter_buffers=self.iter_buffer_ids,
1815
+ store_cache=self.store_cache,
1816
+ varname_map=self.varname_map,
1817
+ reduction_cache=self.reduction_cache,
1818
+ )
1819
+
1820
+ def scoped_copy(self) -> Self:
1821
+ """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
1822
+ new_cse = self.clone()
1823
+ new_cse._cache = ScopedDict(self._cache)
1824
+ new_cse.reduction_cache = ScopedDict(self.reduction_cache)
1825
+ new_cse.store_cache = ScopedDict(self.store_cache)
1826
+ return new_cse
1827
+
1828
+ def augment_key(self, cache_key: str) -> AugmentedKeyT:
1829
+ "Override this method to augment cache key with backend specifics"
1830
+ return cast(AugmentedKeyT, cache_key)
1831
+
1832
+ def put(self, cache_key: str, val: CSEVariableType) -> None:
1833
+ self._cache[self.augment_key(cache_key)] = val
1834
+
1835
+ def contains(self, cache_key: str) -> bool:
1836
+ return self.augment_key(cache_key) in self._cache
1837
+
1838
+ def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
1839
+ return self._cache.get(self.augment_key(cache_key), None)
1840
+
1841
+ def get(self, cache_key: str) -> CSEVariableType:
1842
+ return self._cache[self.augment_key(cache_key)]
1843
+
1844
    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write: bool = True,
        assignment: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariableType:
        """CSE an expression: return the cached variable for *expr* if one
        exists, otherwise allocate a new variable and (when *write*) emit the
        line(s) computing it into *buffer*.

        *assignment*=False emits the expression without binding it to a
        variable; at least one of *write*/*assignment* must be set.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return cast(CSEVariableType, expr)
        elif isinstance(expr, IndentedBuffer):
            cache_key = expr.getvalue()
        elif isinstance(expr, DeferredLineBase):
            cache_key = expr.line
        else:
            assert isinstance(expr, str)
            cache_key = expr
        var = self.try_get(cache_key)
        if not var:
            # Cache miss: mint a fresh variable and remember it for this key.
            var = self.newvar(bounds, dtype)
            self.put(cache_key, var)
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                elif isinstance(expr, DeferredLineBase):
                    assert assignment
                    buffer.writeline(
                        expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}")
                    )
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)

            # cpp backend cannot determine is_vec at this point
            if (
                assignment
                and (
                    config.test_configs.runtime_triton_dtype_assert
                    or config.test_configs.static_cpp_dtype_assert
                )
                and dtype is not None
                and get_current_backend() != "cpp"
            ):
                check_dtype(buffer, var, dtype)

        else:
            # Cache hit: nothing is emitted; just refine bounds and bump usage.
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var
1915
+
1916
+ def newvar(
1917
+ self,
1918
+ bounds: ValueRanges[Any] = ValueRanges.unknown(),
1919
+ dtype: Optional[torch.dtype] = None,
1920
+ ) -> CSEVariableType:
1921
+ var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
1922
+ var = V.kernel.create_cse_var(var_name, bounds, dtype)
1923
+ self.varname_map[var_name] = var
1924
+ return var
1925
+
1926
+ def namedvar(
1927
+ self,
1928
+ name: str,
1929
+ bounds: ValueRanges[Any] = ValueRanges.unknown(),
1930
+ dtype: Optional[torch.dtype] = None,
1931
+ ) -> CSEVariableType:
1932
+ torch._check_value(
1933
+ name not in self.varname_map, lambda: f"duplicate name: {name}"
1934
+ )
1935
+ var = V.kernel.create_cse_var(name, bounds, dtype)
1936
+ self.varname_map[name] = var
1937
+ return var
1938
+
1939
+
1940
class CodeGen:
    """Base class giving code generators a context-manager protocol backed
    by a single ``contextlib.ExitStack`` (see ``Kernel.__enter__``, which
    pushes handler contexts onto it)."""

    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Intentionally does NOT return ExitStack.__exit__'s result, so this
        # context manager never suppresses exceptions.
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
1951
+
1952
+
1953
class Kernel(CodeGen, Generic[CSEVariableType]):
    """Base class for a single generated kernel.

    Holds the three output code sections (``loads``/``compute``/``stores``),
    the CSE state, and bookkeeping about which buffers the kernel reads,
    writes, and can remove. Backend subclasses implement the abstract
    ``load``/``store``/``reduction``/... methods.
    """

    # Text prepended/appended to each generated assignment line.
    newvar_prefix: str = ""
    suffix: str = ""
    # Factory for the ops handler installed in __enter__; subclasses set this.
    overrides: Optional[Callable[[], OpsHandler[Any]]] = None

    def __init__(
        self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
    ) -> None:
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        # The three code sections the kernel body is assembled from.
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()

        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers: OrderedSet[str] = OrderedSet()
        self.store_buffer_names: OrderedSet[str] = OrderedSet()
        self._load_mask: Optional[str] = None
        self._load_other: Union[None, int, float] = None
        # OrderedSet in set_current_node
        self.current_node: Optional[SchedulerNode] = None
        self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers: OrderedSet[str] = OrderedSet()
        self.inplaced_to_remove: OrderedSet[str] = OrderedSet()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers: dict[str, str] = {}
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name: Optional[str] = None

    @contextlib.contextmanager
    def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
        """Temporarily set the scheduler node being generated (and its
        precomputed FX-node bounds); restores the prior node on exit."""
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(
        self,
        lb: IndentedBuffer,
        cb: Optional[IndentedBuffer] = None,
        sb: Optional[IndentedBuffer] = None,
    ) -> Iterator[None]:
        """Temporarily redirect loads/compute/stores into the given buffers,
        with a scoped copy of the CSE state so cached entries created inside
        do not leak out. When no store buffer is given, stores are forbidden
        (asserted empty on exit)."""
        if cb is None:
            cb = lb
        if disallow_stores := sb is None:
            sb = IndentedBuffer()
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.scoped_copy()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse
            if disallow_stores:
                assert not sb, "unexpected store inside swap_buffers"

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """A load the depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
        raise NotImplementedError

    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
        ],
        values: tuple[CSEVariable, ...],
    ) -> tuple[CSEVariable, ...]:
        raise NotImplementedError

    def sort(
        self,
        dtypes: tuple[torch.dtype, ...],
        values: tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> tuple[CSEVariable, ...]:
        raise NotImplementedError

    def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: CSEVariable,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[CSEVariable] = None,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        # Backend-specific name of the runtime assert callable.
        raise NotImplementedError

    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[Union[CSEVariable, str]] = None,
    ) -> str:
        """Build the source for a runtime bounds assert on *var* — i.e.
        ``lower <= var < upper`` — optionally disabled where *mask* is false.
        At least one of *lower*/*upper* must be given."""
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str), type(var)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~({mask})"

        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError

    def __enter__(self) -> Self:
        # Install this kernel and its CSE-wrapping ops handler for the
        # duration of the context; both are popped by CodeGen.__exit__.
        super().__enter__()
        assert self.overrides
        self.exit_stack.enter_context(
            V.set_ops_handler(CSEProxy(self, self.overrides()))
        )
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def remove_kernel_local_buffers(self) -> None:
        """
        Any buffers that are both created and have a last use in the
        same kernel can be removed.

        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        scheduler = V.graph.scheduler
        if not scheduler:
            return
        fused_node_names = OrderedSet(
            scheduler.name_to_buf[buf].defining_op_name()
            for buf in self.store_buffer_names
            if buf in scheduler.name_to_buf
        )
        names_to_remove: OrderedSet[str] = OrderedSet()
        for name in self.store_buffer_names:
            if (
                name not in self.must_keep_buffers
                and name not in self.args.input_buffers
                and scheduler.can_buffer_be_removed_through_fusion(
                    name, fused_node_names
                )
            ):
                names_to_remove.add(name)

        for name in names_to_remove:
            if name in self.args.inplace_buffers:
                buf = self.args.inplace_buffers[name]
                if isinstance(buf, RemovedArg):
                    continue
                # Only drop an inplace buffer if every aliased name is removable.
                remove = all(n in names_to_remove for n in buf.other_names)
                if remove:
                    self.remove_inplace_buffer(name)
                    self.inplaced_to_remove.add(name)
            else:
                self.remove_buffer(name)

    def remove_buffer(self, name: str) -> None:
        # Assign a special value instead of deleting the entry
        # because we still rely on output_buffers's length to
        # generate unique arg name.
        log.debug("remove_buffer(%r)", name)
        self.args.output_buffers[name] = REMOVED
        self.removed_buffers.add(name)

    def remove_inplace_buffer(self, name: str) -> None:
        log.debug("removing_inplace_buffer(%r)", name)
        self.args.inplace_buffers[name] = REMOVED
        self.removed_buffers.add(name)

    def rename_indexing(
        self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
    ) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]
        index = V.graph.sizevars.simplify(index)
        # Sort for deterministic arg registration order.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
        # Backend hook: subclasses return their own CSEVariable subclass.
        return CSEVariable(*args, **kwargs)

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.
        """
        if node is None:
            return None
        return self.args.arg_name(node.get_name())
2238
+
2239
+
2240
@dataclasses.dataclass
class OptimizationContext:
    """Per-FX-node annotation stored in ``node.meta`` (under ``key``),
    carrying dtype information consumed during codegen (see the cpp
    ``masked`` handling in ``CSEProxy._default``)."""

    # The node.meta key under which this context is stored.
    key: ClassVar[str] = "opt_ctx"

    # Inferred dtype of the op's output, when known.
    dtype: Optional[torch.dtype] = None
    ops_name: str = ""
2246
+
2247
+
2248
@functools.cache
def jinja2_env() -> Any:
    """Return a process-wide Jinja2 Environment (StrictUndefined), or None
    when jinja2 is not installed. Cached, so all callers share one instance."""
    try:
        import jinja2
    except ImportError:
        return None

    return jinja2.Environment(
        undefined=jinja2.StrictUndefined,
    )
2258
+
2259
+
2260
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(
        source: str, num_indents: int, indents_spacing: int = 4
    ) -> str:
        """Indent every line of *source* except the first by
        ``num_indents * indents_spacing`` spaces (Jinja filter helper)."""
        lines = source.splitlines(True)
        if len(lines) > 1:
            lines[1:] = [
                (" " * indents_spacing * num_indents) + line for line in lines[1:]
            ]
        return "".join(lines)

    @staticmethod
    def _template_from_string(source: str) -> Any:
        """Compile *source* into a Jinja2 template (None if jinja2 missing).

        On a syntax error, re-raises as a subclass whose str() includes the
        offending line with surrounding context and a column marker.
        """
        env = jinja2_env()
        if env is None:
            return None
        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        from jinja2 import TemplateSyntaxError

        try:
            return env.from_string(source)
        except TemplateSyntaxError as e:

            class DetailedTemplateSyntaxError(TemplateSyntaxError):
                def __init__(self, original_error: TemplateSyntaxError) -> None:
                    super().__init__(
                        original_error.message,
                        original_error.lineno,
                        original_error.name,
                        original_error.filename,
                    )
                    self.original_error = original_error

                def __str__(self) -> str:
                    error_info = f"Error in template at line {self.lineno}\n"
                    error_info += f"Error message: {self.message}\n"
                    if hasattr(self.original_error, "source"):
                        lines = self.original_error.source.split("\n")
                        error_info += "Context:\n"
                        # Show up to two lines before and after the error line.
                        start = max(0, self.lineno - 2)
                        end = min(len(lines), self.lineno + 2)
                        for i in range(start, end):
                            if i == self.lineno - 1:
                                error_info += f"{i + 1}: --> {lines[i]}\n"
                                if hasattr(self.original_error, "column"):
                                    error_info += (
                                        "         "
                                        + " " * (self.original_error.column - 1)
                                        + "^\n"
                                    )
                            else:
                                error_info += f"{i + 1}: {lines[i]}\n"
                    return error_info

            raise DetailedTemplateSyntaxError(e) from e

    @staticmethod
    def _fake_get_dtype(
        fake_outs: Union[list[Buffer], Buffer],
    ) -> Callable[[str], torch.dtype]:
        """Return a get_dtype fn that resolves the given buffer(s) locally
        and falls back to V.graph.get_dtype for everything else."""
        _get_dtype_real = V.graph.get_dtype
        if isinstance(fake_outs, (list, tuple)):
            lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
        else:
            lookup = {fake_outs.get_name(): fake_outs.get_dtype()}

        def get_dtype(name: str) -> torch.dtype:
            result = lookup.get(name)
            if result is not None:
                return result
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str) -> None:
        self.name = name

    def maybe_append_choice(
        self, choices: list[Any], **kwargs: Any
    ) -> Optional[NotImplementedError]:
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.
        Returns None if success, otherwise returns the error.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
            return None
        except NotImplementedError as e:
            log.info(
                "Cannot Append Choice: %s. KernelTemplate type is %s",
                e,
                type(self),
                stack_info=log.getEffectiveLevel() < logging.INFO,
            )
            return e

    def generate(self, **kwargs: Any) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError
2373
+
2374
+
2375
class CSEProxy(DefaultHandler):
    """Ops handler that wraps a backend handler, routing every op's output
    through the kernel's CSE cache and propagating value-range bounds and
    dtypes onto the resulting variables."""

    name = "CSEProxy"

    def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
        super().__init__()
        from ..bounds import ValueRangeAnalysis

        self.vr_analysis = ValueRangeAnalysis()
        self.kernel = kernel
        self.parent_handler = parent_handler

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Fallback for every op without an explicit override: compute bounds,
        delegate to the parent handler, infer the output dtype per backend,
        then CSE each output value."""
        bounds = self._bound_variable(name, *args, **kwargs)

        value = getattr(self.parent_handler, name)(*args, **kwargs)
        dtype_handler = DtypePropagationOpsHandler()

        backend = get_current_backend()

        output_dtype = None
        if name == "masked" and backend == "triton":
            output_dtype = value.dtype
        elif name == "masked" and backend == "cpp":
            output_dtype = V.interpreter.current_node.meta.get(
                OptimizationContext.key, None
            ).dtype
        elif backend in ("triton", "cpp", "mps"):
            dtype_op = getattr(dtype_handler, name)
            output_dtype = dtype_op(*args, **kwargs)

        if backend in ("triton", "cpp"):
            # maybe there are some exceptions on mps?
            assert output_dtype is not None

        output_idx = 0

        def do_cse(v: str) -> CSEVariable:
            # we tree_map over the output, so we need to fetch corresponding dtype
            nonlocal output_idx
            var_dtype: Optional[torch.dtype] = (
                output_dtype[output_idx]
                if isinstance(output_dtype, (list, tuple))
                else output_dtype
            )
            output_idx += 1

            # some cpp op implementations don't set the dtype
            if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None:
                v.dtype = var_dtype

            csevar = V.kernel.cse.generate(
                V.kernel.compute,
                v,
                bounds=bounds,
                dtype=output_dtype,
            )

            csevar.update_on_args(name, args, kwargs)

            if (
                config.test_configs.runtime_triton_dtype_assert
                or config.test_configs.static_cpp_dtype_assert
            ):
                assert var_dtype is not None
                check_dtype(V.kernel.compute, csevar, var_dtype)
            return csevar

        return pytree.tree_map(do_cse, value)

    def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
        """
        If the variable comes from an FX node, we forward the bound we have already computed
        Else, if the variable when codegen'ing another op, we try to compute its bounds
        """
        from ..bounds import ValueRangeAnalysis
        from ..select_algorithm import TritonTemplateKernel
        from .cuda.cuda_kernel import CUDATemplateKernel

        if isinstance(V.kernel, TritonTemplateKernel):
            return ValueRanges.unknown()

        if isinstance(V.kernel, CUDATemplateKernel):
            return ValueRanges.unknown()

        fx_node = V.interpreter.current_node
        if fx_node.target == name and self.kernel.node_to_bounds is not None:
            assert isinstance(self.kernel.node_to_bounds, dict), type(
                self.kernel.node_to_bounds
            )
            return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
        elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
            # These create lots of inner strings. We would need to compute the bounds at the ops
            # We will also likely not get much from computing VRs on these nodes
            if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")):
                return ValueRanges.unknown()

            # We assume that the inputs come from `ops.` and are not strings. If you want to generate
            # intermediary strings, wrap them in CSE variables with properly initialised bounds.

            # If there is no FX bound but we know how to compute one we do so
            assert not kwargs

            def arg_to_bound(x: Any) -> Any:
                if isinstance(x, CSEVariable):
                    return x.bounds
                elif isinstance(x, sympy.Expr):
                    return bound_sympy(x)
                else:
                    return x

            arg_bounds = list(map(arg_to_bound, args))
            return getattr(self.vr_analysis, name)(*arg_bounds)
        return ValueRanges.unknown()

    def indirect_indexing(
        self,
        var: CSEVariable,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> sympy.Symbol:
        """Convert a runtime value into an index symbol, optionally wrapping
        negative indices (Python-style) and emitting bounds asserts."""
        if isinstance(size, int):
            size = sympy.Integer(size)
        assert isinstance(size, sympy.Expr), (type(size), size)
        # Skip CSE since this doesn't return an expression

        if var.bounds.lower < 0:
            if wrap_neg:
                stm = ops.add(var, ops.index_expr(size, torch.long))
                # Mixed negative and non-negative
                if var.bounds.upper >= 0:
                    lt = ops.lt(var, 0)
                    stm = ops.where(lt, stm, var)
            else:
                stm = var

            # Propagate bounds as we know how to compute them properly
            new_bounds = ValueRanges.unknown()
            if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
                # Take the negative part of the bound and add size to it
                # Then take union of that and the positive part
                # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                new_bounds = ValueRanges(
                    neg_bounds.lower + size, neg_bounds.upper + size
                )
                # We don't have a good way of representing the empty range
                if var.bounds.upper >= 0:
                    pos = var.bounds & ValueRanges(0, int_oo)
                    new_bounds = new_bounds | pos

            var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)

        sympy_var = self.parent_handler.indirect_indexing(var, size, check)
        if generate_assert(check):
            assert_lower = not (var.bounds.lower >= 0)
            # value ranges cannot x < s when x and s are symbols
            assert_upper = not isinstance(size, sympy.Number) or not (
                var.bounds.upper < size
            )
            self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper)
        return sympy_var

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        return self.kernel.check_bounds(expr, size, lower, upper)

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Load buffer *name* at *index*, short-circuiting through the
        store cache when the value was written earlier in this kernel."""
        if name in self.kernel.cse.invalidated_stores:
            # A load from an invalidated store requires us to
            # keep the actual buffer around
            V.kernel.must_keep_buffers.add(name)
        if free_symbol_is_type(index, SymT.TMP):
            return self.kernel.indirect_load(name, index)
        store_cache = self.kernel.cse.store_cache
        if name in store_cache:
            return store_cache[name]
        out = self.kernel.load(name, index)
        # count load that is not in the store_cache, and also not in the
        # cse cache.
        if out.use_count == 1:
            self.kernel.num_load += 1
        return out

    def _update_store_cache(self, name: str, value: CSEVariable) -> None:
        # Record the stored value for *name* and for every buffer it mutates,
        # so subsequent loads can be served from the cache.
        self.kernel.cse.store_cache[name] = value
        if self.kernel.current_node and name in V.graph.name_to_buffer:
            buf = self.kernel.current_node.get_output(name)
            for other_name in buf.get_mutations():
                self.kernel.cse.store_cache[other_name] = value

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        self.kernel.store_buffer_names.add(name)
        if mode is None:
            self._update_store_cache(name, value)
        if name not in V.graph.removed_buffers:
            self.kernel.store(name, index, value, mode=mode)

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
        self.kernel.store_buffer_names.add(name)
        self._update_store_cache(name, value)

        if name not in V.graph.removed_buffers:
            return self.kernel.store_reduction(name, index, value)

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
        self.kernel.num_reduction += 1
        return self.kernel.reduction(dtype, src_dtype, reduction_type, value)

    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]],
            tuple[CSEVariable, ...],
        ],
        values: tuple[CSEVariable, ...],
    ) -> tuple[CSEVariable, ...]:
        return self.kernel.scan(dtypes, combine_fn, values)

    def sort(
        self,
        dtypes: tuple[torch.dtype, ...],
        values: tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> tuple[CSEVariable, ...]:
        return self.kernel.sort(dtypes, values, stable, descending)

    def bucketize(
        self,
        values: CSEVariable,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: CSEVariable,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[CSEVariable] = None,
    ) -> CSEVariable:
        """
        [Note: Inductor bucketize op]

        Inputs:
        -------
        values: the values to be bucketized.
        boundaries: a tuple containing
          (a) the name of the boundaries tensor (which must be sorted, unless
          the sorting tensor is present),
          (b) the length of the tensor in the last dimension (i.e. the length of
          one set of boundaries),
          (c) the number of elements in the underlying storage (i.e. the length
          of the flattened tensor, ignoring striding), and
          (d) the stride of the tensor in the last dimension.
        boundary_indices: indices into a flattened version of the boundaries
        tensor, of the same size and shape as "values". Each index points to
        the first element in the set of boundaries to be used for the
        corresponding value.
        indexing_dtype: the dtype to use when indexing into the boundaries
        tensor. This must be int64 or int32. This additionally specifies the
        dtype of the return value.
        right: see "Details" below.
        sorter: an optional tuple containing
          (a) the name of an optional sorting tensor, used to access unsorted
          boundaries without reordering the boundaries tensor, and
          (b) the stride of the tensor in the last dimension.
        The values in the sorting tensor are used as indices into the *last*
        dimension of the boundaries tensor, with all other indices matching.
        The size of the sorting and boundaries tensors must be equivalent.
        sorter_indices: must be present if the sorting array is present; see
        "boundary_indices" for the equivalent definition for the boundaries
        tensor.

        Output:
        -------
        The buckets each value belongs in, within a given set of boundaries. 0
        indicates a position before the first boundary, and len(boundaries_set)
        represents a position after the last boundary.

        Details:
        --------
        Given a value and a set of boundaries, calculate the bucket that each
        value belongs to. This works differently in 1-D and N-D cases.

        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True
        return = [[ 0, 1, 1, 1], [1, 3, 3, 4]].

        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True
        return = [[ 0, 1, 1, 1], [0, 1, 1, 2]]

        Note that in the N-D boundaries case, the shape of "values" and
        "boundaries" must match in every dimension _except_ the last.

        When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]].
        When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]).

        Boundaries must be non-decreasing, or a sorter must be provided which
        would re-index offsets in a non-decreasing order (e.g. the second output
        of torch.sort(offsets)). Otherwise, the result is undefined.
        """
        return self.kernel.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import itertools
4
+ from typing import Any, Callable, Optional
5
+ from unittest.mock import patch
6
+
7
+ import sympy
8
+
9
+ from .. import ir
10
+ from ..select_algorithm import PartialRender
11
+ from ..virtualized import V
12
+ from .common import ArgName
13
+ from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
14
+ from .cpp_micro_gemm import LayoutType
15
+ from .cpp_template_kernel import CppTemplateKernel
16
+ from .cpp_utils import DTYPE_TO_CPP, GemmBlocking
17
+
18
+
19
# Jinja stubs that emit the *definitions* of the two per-batch GEMM kernels
# (a single-threaded and a multi-threaded variant) invoked by BMM_TEMPLATE.
# Each stub only renders the function signature via kernel.def_kernel; the
# body is appended from GEMM_TEMPLATE by the codegen_* methods on the class.
# `b_index` is passed as an extra sizevar so the 2D GEMM body can slice the
# 3D batch tensors; the "<...>" placeholder strings are resolved later by
# PartialRender hooks.
# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition
GEMM_SINGLE_THREAD_MM_STUB = r"""
{{kernel.def_kernel(
    inputs={"X": X, "W": W},
    outputs={"Y": Y_2d},
    aliases=aliases,
    function_name=kernel_name+"_single_thread_mm",
    extra_sizevars=BY_sizevars + [b_index],
    placeholder="<SINGLE_THREAD_MM_DEF_FOR_BMM>")}}"""

# Same as GEMM_SINGLE_THREAD_MM_STUB but for the multi-threaded GEMM variant.
GEMM_THREADED_MM_STUB = r"""
{{kernel.def_kernel(
    inputs={"X": X, "W": W},
    outputs={"Y": Y_2d},
    aliases=aliases,
    function_name=kernel_name+"_threaded_mm",
    extra_sizevars=BY_sizevars + [b_index],
    placeholder="<THREADED_MM_DEF_FOR_BMM>")}}"""

# Top-level BMM kernel: emits the micro-kernel and both GEMM variants, then a
# C entry point that loops over the batch dimension B.  The first
# (B / num_threads) * num_threads batches are distributed across threads via
# OpenMP, each calling the single-threaded GEMM; the remaining tail batches
# each call the internally-threaded GEMM so all cores stay busy.
BMM_TEMPLATE = r"""
{{ template.codegen_microkernel_def() }}
{{ template.codegen_single_thread_gemm() }}
{{ template.codegen_multi_thread_gemm() }}

extern "C"
{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}}
{
    const int64_t B = {{kernel.size(BY_2d, 0)}};
    {%- if num_threads > 1 %}
    constexpr int64_t num_threads = {{num_threads}};
    int64_t B_single_thread_block = (B / num_threads) * num_threads;

    #pragma omp parallel for num_threads({{num_threads}})
    {%- else %}
    int64_t B_single_thread_block = B;
    {%- endif %}
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        {{template.get_gemm_function_call(
            kernel,
            kernel_name+"_single_thread_mm",
            "<SINGLE_THREAD_CALL_FOR_BMM>",
            b_index="b_start",
        )}}
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        {{template.get_gemm_function_call(
            kernel,
            kernel_name+"_threaded_mm",
            "<THREADED_MM_CALL_FOR_BMM>",
            b_index="b_start",
        )}}
    }
}
"""
73
+
74
+
75
class CppBmmTemplate(CppGemmTemplate):
    """
    C++ template for batched matmul (BMM), built on top of the 2D GEMM template.

    The batch dimension is handled by an extra sizevar ``b_index`` (a sympy
    symbol) that the GEMM template uses to slice the 3D inputs, so the GEMM
    body itself is reused unchanged.  See ``__init__`` for the threading
    strategy.
    """

    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
        has_bias=False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = False,
        name="bmm",
    ):
        """
        In order to simplify the implementation and increase code reuse, the BMM template implements
        two versions of the GEMM kernel: a single-threaded version and a multi-threaded version.
        GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls
        for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel.

        We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM
        template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template
        without any changes to the GEMM template itself.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            should_block_weights=should_block_weights,
            name=name,
        )
        # Symbol that indexes the batch dimension when slicing the 3D tensors.
        self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True)

    @staticmethod
    def get_padded_size(n, block_n, k, should_block_weight):
        """
        Return ``(new_size, padded_n)`` for the (possibly blocked) weight.

        When blocking is required, delegate to the 2D GEMM sizing and prepend a
        wildcard batch dimension; otherwise keep the plain ``[B, k, n]`` shape.
        """
        if should_block_weight:
            # Tensor is constant or not contiguous, so we will pad and block
            new_size, padded_n = CppGemmTemplate.get_padded_size(
                n, block_n, k, should_block_weight
            )
            # Add the new batch dimension
            new_size.insert(0, -1)
            return new_size, padded_n
        else:
            new_size = [-1, k, n]
            return new_size, n

    @staticmethod
    def check_if_block_weight(W, micro_gemm):
        """
        Decide whether the weight tensor ``W`` must be repacked/blocked.

        Blocking is needed when W is non-contiguous, is a graph constant, or
        when N is not a multiple of the micro-kernel's block_n while the
        micro-kernel expects a non-NORMAL (e.g. VNNI) B layout.
        """
        assert isinstance(W, ir.IRNode)
        _, n = W.get_size()[-2:]
        return (
            not W.get_layout().is_contiguous()
            or W.get_name() in V.graph.constants
            or (
                n % micro_gemm.register_blocking.block_n != 0
                # Fix: get_b_layout is a method and must be *called*.  The
                # previous code compared the bound-method object itself to
                # LayoutType.NORMAL, which is always unequal, so this clause
                # degenerated to just the block_n divisibility check.
                and micro_gemm.get_b_layout() != LayoutType.NORMAL
            )
        )

    def get_gemm_function_call(
        self,
        kernel: CppTemplateKernel,
        function_name: str,
        placeholder: str,
        b_index: str,
    ) -> str:
        """
        Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
        generate a function call for the GEMM kernel.
        Args:
            placeholder: The string to replace the function call with
            b_index: The index for slicing the 3D batch tensors
        """

        def hook():
            # Rebuild the arg list at render time; substitute the symbolic
            # batch index with the loop variable name (e.g. "b_start").
            arg_defs, call_args, _, _ = kernel.args.python_argdefs()
            for i, buf in enumerate(call_args):
                if buf == self.b_index:
                    arg_defs[i] = ArgName(b_index)
            return f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"

        assert placeholder not in kernel.render_hooks
        kernel.render_hooks[placeholder] = hook
        return placeholder

    def get_default_reindexers(self, epilogue_nodes):
        def reindexer(args):
            # Epilogue nodes have 3D ranges but args are 2D, so prepend the
            # batch index symbol.
            return [self.b_index] + args

        return [reindexer] * len(epilogue_nodes)

    def get_options(
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[list[ir.IRNode]] = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Extend the GEMM render options with BMM-specific entries: keep the
        original 3D buffers under B-prefixed keys and replace X/W/Y_2d/GemmOut
        with 2D slices taken at ``self.b_index``.
        """
        options = super().get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )

        BX, BW, BY = options["X"], options["W"], options["Y"]
        options["BX"], options["BW"], options["BY"] = BX, BW, BY
        options["BY_2d"] = options["Y_2d"]
        for kword in ["X", "W", "GemmOut", "Y_2d"]:
            options[kword] = kernel.select(options[kword], 0, self.b_index)
        for kword in ["X", "W", "Y_2d"]:
            options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype]
        options["b_index"] = self.b_index
        # Free symbols of BY's sizes/strides are forwarded so the GEMM stubs
        # do not rename sizevars inside the BMM definition.
        options["BY_sizevars"] = [
            s
            for sym in itertools.chain(BY.get_size(), BY.get_stride())
            if isinstance(sym, sympy.Expr)
            for s in sym.free_symbols
        ]
        options["kernel_name"] = kernel.kernel_name

        return options

    def render(  # type: ignore[override, return]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[list[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """
        Render the full BMM kernel source.  The inner GEMM stubs register
        "*_FOR_BMM" placeholder hooks, which are finalized here (and removed
        from the kernel so the outer render does not see them again).
        """
        options = self.get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )
        # Saved so codegen_single_thread_gemm/codegen_multi_thread_gemm
        # (invoked from within the template) can reuse the same options.
        self.render_options = options

        with contextlib.ExitStack() as stack:
            for buf in options["fake_buffers"]:
                stack.enter_context(
                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
                )
            result = self._template_from_string(BMM_TEMPLATE).render(**options)

        # Finalize the function definitions for the gemm routines
        sub_mm_hooks = {
            name: hook
            for name, hook in kernel.render_hooks.items()
            if "FOR_BMM" in name
        }
        result = PartialRender(result, sub_mm_hooks).finalize_all()
        for name in sub_mm_hooks:
            del kernel.render_hooks[name]
        # b_index is only meaningful inside the generated GEMM bodies; drop it
        # so it is not emitted as a kernel argument of the outer BMM kernel.
        del kernel.args.sizevars[options["b_index"]]
        return result

    def codegen_single_thread_gemm(self):
        """Emit the definition of the single-threaded GEMM variant."""
        stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            {**self.render_options, "num_threads": 1}
        )

    def codegen_multi_thread_gemm(self):
        """Emit the definition of the internally-threaded GEMM variant."""
        stub = self._template_from_string(GEMM_THREADED_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            self.render_options
        )

    def codegen_gemm_stub_def(self):
        # The GEMM signatures are emitted by the *_MM_STUB templates instead.
        return ""
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py ADDED
@@ -0,0 +1,1081 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import logging
4
+ import re
5
+ from typing import Optional
6
+ from unittest.mock import patch
7
+
8
+ import sympy
9
+
10
+ import torch
11
+ import torch.utils
12
+
13
+ from ...utils._ordered_set import OrderedSet
14
+ from .. import ir
15
+ from ..ir import TensorBox
16
+ from ..select_algorithm import DataProcessorTemplateWrapper
17
+ from ..utils import parallel_num_threads
18
+ from ..virtualized import V
19
+ from .cpp_template import CppTemplate
20
+ from .cpp_utils import GemmBlocking
21
+
22
+
23
log = logging.getLogger(__name__)

# C++ helper kernels emitted once per flex-attention kernel instance.  The
# {{kernel_name}} prefix keeps the symbols unique per generated kernel.
# TODO: reuse cpp codegen to generate below pointwise/reduction kernels
SOFTMAX_FUSIONS = r"""
// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename T1, typename T2>
inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel(
    T1* a,
    const int& size,
    T2* out,
    T1& val) {
  auto vec_size = at::vec::Vectorized<T1>::size();
  auto vec_max = at::vec::Vectorized<T1>(val);
  T1 tmp_sum = 0;
  auto vec_tmp_sum = at::vec::Vectorized<T1>(tmp_sum);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
    auto tmp1 = tmp0 - vec_max;
    auto tmp2 = tmp1.exp_u20();
    vec_tmp_sum += tmp2;
    at::native::_store(out + i, tmp2);
  }
  tmp_sum = at::vec::vec_reduce_all<T1>(
      [](at::vec::Vectorized<T1>& x, at::vec::Vectorized<T1>& y) {
        return x + y;
      },
      vec_tmp_sum);
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 - val;
    auto tmp2 = exp(tmp1);
    tmp_sum += tmp2;
    out[i] = tmp2;
  }
  val = tmp_sum;
}

// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
inline void {{kernel_name}}_mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = at::vec::Vectorized<scalar_t>::size();
  auto vec_scale = at::vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = at::vec::Vectorized<scalar_t>(tmp_max);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1);
    at::native::_store(out + i, tmp1);
  }
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  max = std::max(
      tmp_max,
      at::vec::vec_reduce_all<scalar_t>(
          [](at::vec::Vectorized<scalar_t>& x, at::vec::Vectorized<scalar_t>& y) {
            return at::vec::maximum(x, y);
          },
          vec_tmp_max));
}

template <typename scalar_t>
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  TORCH_CHECK(ptr2 == nullptr);
  return ptr;
}

template <typename scalar_t,
          typename std::enable_if_t<c10::is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) {
  return ptr2;
}

template <typename scalar_t>
inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  Vec data_vec = Vec(val);
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    data_vec.store(data + d);
  }
  #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
  # pragma unroll
  #endif
  for (; d < size; d++) {
    data[d] = val;
  }
}

// out = a * scale
template <typename scalar_t>
inline void {{kernel_name}}_mul_scale_kernel(
    scalar_t* a,
    scalar_t scale,
    int64_t size) {
  auto vec_size = at::vec::Vectorized<scalar_t>::size();
  auto vec_scale = at::vec::Vectorized<scalar_t>(scale);
  for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    at::native::_store(a + i, tmp1);
  }
  for (int64_t i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    a[i] = tmp1;
  }
}

"""

# C++ helper that copies a [rows, cols] tile into a zero-padded
# [prows, pcols] destination (used when packing K/V for brgemm).
BRGEMM_PACK_FUNCTIONS = r"""
template <typename scalar_t>
inline void {{kernel_name}}_copy_value_with_pad(
    const scalar_t* value_ptr,
    scalar_t* dst_ptr,
    int64_t rows,
    int64_t cols,
    int64_t prows,
    int64_t pcols,
    int64_t ldi) {
  auto vec_size = at::vec::Vectorized<scalar_t>::size();
  int64_t i = 0;
  for (; i < rows; i++) {
    int64_t j = 0;
    for (; j < cols - (cols % vec_size); j += vec_size) {
      auto vec_v =
          at::vec::Vectorized<scalar_t>::loadu(value_ptr + i * ldi + j);
      vec_v.store(dst_ptr + i * pcols + j);
    }

    if (j < cols) {
      auto vec_v = at::vec::Vectorized<scalar_t>::loadu(
          value_ptr + i * ldi + j, cols - j);
      vec_v.store(dst_ptr + i * pcols + j, cols - j);
    }

    // col padding
    auto psize = pcols - cols;
    if (psize > 0) {
      auto zero_vec = at::vec::Vectorized<scalar_t>(0);
      int64_t pj = 0;
      for (; pj < psize - (psize % vec_size); pj += vec_size) {
        zero_vec.store(dst_ptr + i * pcols + cols + pj);
      }
      if (pj < psize) {
        zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
      }
    }
  }
  // row padding
  for (; i < prows; i++) {
    auto zero_vec = at::vec::Vectorized<scalar_t>(0);
    int64_t j = 0;
    for (; j < pcols - (pcols % vec_size); j += vec_size) {
      zero_vec.store(dst_ptr + i * pcols + j);
    }
    if (j < pcols) {
      zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
    }

  }
}
"""

# Placeholder token replaced with the generated micro-GEMM definition.
MICRO_GEMM_TEMPLATE = r"""
GEMM_DEFINE
"""

# Snippet that allocates a named scratch buffer of {{buffer_size}} elements
# via the CPU allocator; item size is 2 bytes for reduced-precision floats
# (bf16/fp16), else 4.
ALLOCATE_BUFFER = r"""
  int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4;
  auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
  auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize);
  void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get();
  {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr;
"""
210
+
211
+ FLEX_ATTENTION_TEMPLATE = r"""
212
+ {{template.header().getvalue()}}
213
+ #include <ATen/native/cpu/utils.h>
214
+ #include <ATen/native/CPUBlas.h>
215
+ #include <ATen/Context.h>
216
+ {{template.codegen_micro_gemm(kernel.kernel_name)}}
217
+ {{template.codegen_softmax_fusion(kernel.kernel_name)}}
218
+ {{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
219
+ {%- set kernel_args = {"query": query, "key": key, "value": value,
220
+ "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices,
221
+ "full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %}
222
+ {%- set kernel_args = template.update_kernel_args(kernel_args) %}
223
+
224
+ extern "C"
225
+ {{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}}
226
+ {
227
+ {{ kernel.maybe_codegen_profile() }}
228
+ int64_t qBlockSize = {{qBlockSize}};
229
+ int64_t kvBlockSize = {{kvBlockSize}};
230
+ int64_t num_thread = {{num_thread}};
231
+
232
+ // dtypes of kernel and internal buffers
233
+ using scalar_t = {{kernel.dtype(query)}};
234
+ constexpr bool is_reduced_type = c10::is_reduced_floating_point_v<scalar_t>;
235
+ using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
236
+ using Vec = at::vec::Vectorized<accum_t>;
237
+ accum_t scaling_factor = {{scale}};
238
+ int64_t batchSize = {{kernel.size(query, 0)}};
239
+ int64_t qSize = {{kernel.size(query, 1)}};
240
+ int64_t num_head = {{kernel.size(query, 2)}};
241
+ int64_t headSize = {{kernel.size(query, 3)}};
242
+ int64_t batchSize_k = {{kernel.size(key, 0)}};
243
+ int64_t num_head_k = {{kernel.size(key, 2)}};
244
+ int64_t headSize_v = {{kernel.size(value, 3)}};
245
+ bool is_broadcast_bs_kv = batchSize != batchSize_k;
246
+ bool is_broadcast_head_kv = num_head != num_head_k;
247
+ int64_t gqa_shards = num_head / num_head_k;
248
+ int64_t bs_shards = batchSize / batchSize_k;
249
+
250
+ int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}};
251
+ int64_t num_head_kvi = {{kernel.size(kv_indices, 1)}};
252
+ int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}};
253
+ bool is_broadcast_bs_kvi = batchSize != batchSize_kvi;
254
+ bool is_broadcast_head_kvi = num_head != num_head_kvi;
255
+ int64_t gqa_shards_kvi = num_head / num_head_kvi;
256
+ int64_t bs_shards_kvi = batchSize / batchSize_kvi;
257
+
258
+ int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}};
259
+ int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}};
260
+ int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}};
261
+
262
+ int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}};
263
+ int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}};
264
+
265
+ {%- if has_full_kv_block %}
266
+ int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}};
267
+ int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}};
268
+ int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}};
269
+
270
+ int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}};
271
+ int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}};
272
+ auto full_kv_indices_data = full_kv_indices;
273
+ auto full_kv_num_blocks_data = full_kv_num_blocks;
274
+ {%- endif %}
275
+
276
+ auto kv_num_blocks_data = kv_num_blocks;
277
+ auto kv_indices_data = kv_indices;
278
+
279
+ // Strides
280
+ int64_t qStrideB = {{kernel.stride(query, 0)}};
281
+ int64_t qStrideM = {{kernel.stride(query, 1)}};
282
+ int64_t qStrideH = {{kernel.stride(query, 2)}};
283
+ int64_t kStrideB = {{kernel.stride(key, 0)}};
284
+ int64_t kStrideN = {{kernel.stride(key, 1)}};
285
+ int64_t kStrideH = {{kernel.stride(key, 2)}};
286
+ int64_t vStrideB = {{kernel.stride(value, 0)}};
287
+ int64_t vStrideN = {{kernel.stride(value, 1)}};
288
+ int64_t vStrideH = {{kernel.stride(value, 2)}};
289
+ int64_t oStrideB = {{kernel.stride(output, 0)}};
290
+ int64_t oStrideM = {{kernel.stride(output, 2)}};
291
+ int64_t oStrideH = {{kernel.stride(output, 1)}};
292
+
293
+ int64_t kvSize = {{kernel.size(key, 1)}};
294
+
295
+ int64_t qSplitSize = qBlockSize;
296
+ int64_t kvSplitSize = kvBlockSize;
297
+
298
+
299
+ qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
300
+ kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
301
+ int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
302
+ int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
303
+ int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
304
+
305
+ bool need_pack = false;
306
+ // Whether pack is needed for BFloat16/Half
307
+ if (is_reduced_type) {
308
+ // check platform ability
309
+ need_pack = std::is_same_v<scalar_t, at::BFloat16> ? at::native::cpublas::could_pack(at::kBFloat16)
310
+ : at::native::cpublas::could_pack(at::kHalf);
311
+ }
312
+ if (need_pack) {
313
+ // When the number of gemm is greater than the number of pack,
314
+ // the pack overhead can be overlapped.
315
+ int64_t thresh_size = 64;
316
+ need_pack = kvSize >= thresh_size && qSize >= thresh_size;
317
+ if (need_pack) {
318
+ double pack_size = batchSize * num_head * kvSize * headSize;
319
+ double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread;
320
+ double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize;
321
+ need_pack = gemm_size_per_thread / pack_size >= 4;
322
+ }
323
+ }
324
+ // Pad is needed for packing when K is not even
325
+ bool headSize_even = headSize % 2 == 0;
326
+ int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
327
+ int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
328
+ int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
329
+ int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
330
+
331
+ // Allocate per thread temp buf (accumulate type)
332
+ int64_t _size_per_thread =
333
+ /* qk */ qSplitSize * kvSplitSize +
334
+ /* qk_max */ qSplitSize +
335
+ /* qk_sum */ qSplitSize +
336
+ /* dst */ qSplitSize * headSize_v;
337
+
338
+ // Inputs/outputs buffers
339
+ const scalar_t* q_data = query;
340
+ const scalar_t* k_data = key;
341
+ const scalar_t* v_data = value;
342
+ scalar_t* out_data = output;
343
+
344
+ // Buffers to store accum results, padding query and transpose/packing key/value
345
+ {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}}
346
+ {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}}
347
+ {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}}
348
+ {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}}
349
+ {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
350
+ {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}
351
+ if (need_pack) {
352
+ // Pack K, V
353
+ at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) {
354
+ int ompIdx = at::get_thread_num();
355
+ int64_t i = 0, j = 0, l = 0, n = 0;
356
+ scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
357
+ at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice);
358
+ for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
359
+ n = l * kvSplitSize;
360
+ int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
361
+ auto k_addr =
362
+ k_data + i * kStrideB + j * kStrideH + n * kStrideN;
363
+ auto v_addr =
364
+ v_data + i * vStrideB + j * vStrideH + n * vStrideN;
365
+ // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
366
+ at::native::utils::transpose<uint16_t>(
367
+ cur_kvSplitSize,
368
+ headSize,
369
+ /* src_ptr */
370
+ reinterpret_cast<const uint16_t*>(k_addr),
371
+ /* ld_src */ kStrideN,
372
+ /* dst */ reinterpret_cast<uint16_t*>(transpose_ptr),
373
+ /* ld_dst */ cur_kvSplitSize);
374
+
375
+ // Pack [headSize, cur_kvSplitSize]
376
+ at::vec::pack_vnni2(
377
+ /* src */ reinterpret_cast<const uint16_t*>(transpose_ptr),
378
+ /* dst */ reinterpret_cast<uint16_t*>(key_reorder_ptr + i * num_head_k * eheadSize * kvSize +
379
+ j * eheadSize * kvSize + n * eheadSize),
380
+ /* ld_src */ cur_kvSplitSize,
381
+ /* K */ headSize,
382
+ /* N */ cur_kvSplitSize);
383
+
384
+ // Pack [cur_kvSplitSize, headSize_v]
385
+ at::vec::pack_vnni2(
386
+ /* src */ reinterpret_cast<const uint16_t*>(v_addr),
387
+ /* dst */ reinterpret_cast<uint16_t*>(value_reorder_ptr +
388
+ i * num_head_k * kv_padding_size * headSize_v +
389
+ j * kv_padding_size * headSize_v + n * headSize_v),
390
+ /* ld_src */ vStrideN,
391
+ /* K */ cur_kvSplitSize,
392
+ /* N */ headSize_v);
393
+ // Move to the next query
394
+ at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice);
395
+ }
396
+ });
397
+ }
398
+ // Attention loop below
399
+ at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
400
+ int64_t i = 0, j = 0, k = 0;
401
+ at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
402
+ int ompIdx = at::get_thread_num();
403
+ accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread;
404
+ accum_t* qk_data = buf_ptr;
405
+ accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
406
+ accum_t* qk_sum_data = qk_max_data + qSplitSize;
407
+ accum_t* dst_data = qk_sum_data + qSplitSize;
408
+ scalar_t *qk_reduced_data =
409
+ is_reduced_type
410
+ ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize
411
+ : nullptr;
412
+ scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
413
+ ? query_padding_ptr + ompIdx * qSplitSize * eheadSize
414
+ : nullptr;
415
+
416
+ for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
417
+ auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
418
+ auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
419
+ auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB +
420
+ j_kvi * num_kviStrideH + k;
421
+ int kv_indice_num = *kv_logical_num_data;
422
+ std::vector<int> kv_indice_list(kv_indice_num);
423
+ for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){
424
+ auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
425
+ j_kvi * kviStrideH + k*kviStrideQ + kv_i;
426
+ kv_indice_list[kv_i] = *kv_logical_data;
427
+ }
428
+ bool is_skip_kv = kv_indice_num > 0 ? false : true;
429
+ {%- if has_full_kv_block %}
430
+ auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB +
431
+ j_kvi * num_kviStrideH + k;
432
+ int full_kv_indice_num = *full_kv_logical_num_data;
433
+ std::vector<int> full_kv_indice_list(full_kv_indice_num);
434
+ for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){
435
+ auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB +
436
+ j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i;
437
+ full_kv_indice_list[kv_i] = *full_kv_logical_data;
438
+ }
439
+ is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true;
440
+ {%- endif %}
441
+ int64_t m = k * qSplitSize;
442
+ int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m);
443
+ if (!is_skip_kv){
444
+ // Initialize max and sum
445
+ {{kernel.kernel_name}}_fill_stub(qk_max_data,
446
+ -std::numeric_limits<accum_t>::infinity(), cur_qSplitSize);
447
+ {{kernel.kernel_name}}_fill_stub(qk_sum_data,
448
+ static_cast<accum_t>(0), cur_qSplitSize);
449
+
450
+ if (!headSize_even && need_pack) {
451
+ // Pad query if headSize is not even
452
+ {{kernel.kernel_name}}_copy_value_with_pad<scalar_t>(
453
+ q_data + i * qStrideB + j * qStrideH + m * qStrideM,
454
+ query_t_padding_ptr,
455
+ cur_qSplitSize,
456
+ headSize,
457
+ cur_qSplitSize,
458
+ eheadSize,
459
+ qStrideM
460
+ );
461
+ }
462
+ }
463
+
464
+ {%- if has_full_kv_block %}
465
+ for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) {
466
+ auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
467
+ {%- else %}
468
+ for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) {
469
+ auto n = kv_indice_list[n_idx]*kvSplitSize;
470
+ {%- endif %}
471
+
472
+ auto cur_n = n/kvSplitSize;
473
+ int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
474
+ int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize;
475
+
476
+ // Calculate scale * q @ k.T
477
+ auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
478
+ auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
479
+
480
+ if (!need_pack) {
481
+ auto k_addr =
482
+ k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
483
+
484
+ {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<static_cast<bool>(false)>(
485
+ q_data + i * qStrideB + j * qStrideH +
486
+ m * qStrideM,
487
+ k_addr,
488
+ qk_data,
489
+ cur_qSplitSize,
490
+ cur_kvSplitSize,
491
+ headSize,
492
+ qStrideM,
493
+ kStrideN,
494
+ cur_kvSplitSize);
495
+
496
+ } else {
497
+ at::native::cpublas::brgemm(
498
+ cur_qSplitSize,
499
+ cur_kvSplitSize,
500
+ eheadSize,
501
+ headSize_even ? qStrideM : eheadSize,
502
+ cur_kvSplitSize,
503
+ cur_kvSplitSize,
504
+ false,
505
+ !headSize_even
506
+ ? query_t_padding_ptr
507
+ : q_data + i * qStrideB + j * qStrideH + m * qStrideM,
508
+ key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize +
509
+ j_kv * eheadSize * kvSize + n * eheadSize,
510
+ qk_data,
511
+ need_pack);
512
+ }
513
+
514
+ {{kernel.kernel_name}}_mul_scale_kernel<accum_t>(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
515
+
516
+ {%- if score_mod and mask_mod %}
517
+ // TODO: reduce the number of calls of q_idx and kv_idx initialization
518
+ std::vector<int64_t> q_idx(cur_qSplitSize);
519
+ for (int64_t i = 0; i < cur_qSplitSize; ++i) {
520
+ q_idx[i] = m + i;
521
+ }
522
+
523
+ std::vector<int64_t> kv_idx(cur_kvSplitSize);
524
+ for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
525
+ kv_idx[i] = n + i;
526
+ }
527
+
528
+ std::vector<int64_t> b_idx = {i};
529
+ std::vector<int64_t> h_idx = {j};
530
+
531
+ accum_t* in_ptr0 = qk_data;
532
+
533
+ auto in_ptr1 = b_idx.data();
534
+ auto in_ptr2 = h_idx.data();
535
+ auto in_ptr3 = q_idx.data();
536
+ auto in_ptr4 = kv_idx.data();
537
+
538
+ // apply score mod function
539
+ {
540
+ {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
541
+ accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
542
+ {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
543
+ }
544
+
545
+ if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){
546
+ // Apply block mask, fill unused with -inf
547
+ {
548
+ {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
549
+ accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
550
+ {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
551
+ }
552
+ }
553
+
554
+ {%- endif %}
555
+ // Update coefficients with Softmax
556
+ accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
557
+ for (int64_t row = 0; row < cur_qSplitSize; ++row) {
558
+ // apply scaling factor and max per row in fusion
559
+ {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel(
560
+ qk_data + row * cur_kvSplitSize,
561
+ static_cast<accum_t>(1),
562
+ cur_kvSplitSize,
563
+ qk_data + row * cur_kvSplitSize,
564
+ tmp_max);
565
+ tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
566
+ if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
567
+ // to avoid `nan = exp2f(-inf - (-inf))`
568
+ {{kernel.kernel_name}}_fill_stub(
569
+ {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
570
+ static_cast<scalar_t>(0), cur_kvSplitSize);
571
+ } else {
572
+ tmp_sum = tmp_max;
573
+ // qk <- exp(qk - max) and sum per row
574
+ {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel(
575
+ qk_data + row * cur_kvSplitSize, cur_kvSplitSize,
576
+ {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
577
+ tmp_sum);
578
+ // exp_tmp <- exp(max[row] - max)
579
+ exp_tmp = std::exp(qk_max_data[row] - tmp_max);
580
+ // sum[row] <- sum + exp_tmp * sum[row]
581
+ qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
582
+ // max[row] <- max
583
+ qk_max_data[row] = tmp_max;
584
+ // dst <- dst * exp_tmp
585
+ if (n_idx > 0) {
586
+ at::vec::map<accum_t>(
587
+ [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
588
+ dst_data + row * headSize_v,
589
+ dst_data + row * headSize_v,
590
+ headSize_v);
591
+ }
592
+ }
593
+ if (need_pack && cur_kvSplitSize % 2 != 0) {
594
+ // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1]
595
+ *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0);
596
+ }
597
+ }
598
+ // Calculate Softmax(q @ k.T) @ v
599
+ if (!need_pack) {
600
+ auto v_addr =
601
+ v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
602
+ // Fallback Half brgemm is slower than micro gemm
603
+ if (!std::is_same_v<scalar_t, at::Half>) {
604
+ at::native::cpublas::brgemm(
605
+ cur_qSplitSize,
606
+ headSize_v,
607
+ cur_ekvSplitSize,
608
+ cur_ekvSplitSize,
609
+ vStrideN,
610
+ headSize_v,
611
+ n_idx > 0,
612
+ {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
613
+ v_addr,
614
+ dst_data,
615
+ need_pack);
616
+ } else {
617
+ if (n_idx > 0) {
618
+ {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
619
+ {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
620
+ v_addr,
621
+ dst_data,
622
+ cur_qSplitSize,
623
+ headSize_v,
624
+ cur_ekvSplitSize,
625
+ cur_ekvSplitSize,
626
+ vStrideN,
627
+ headSize_v);
628
+ } else {
629
+ {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
630
+ {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
631
+ v_addr,
632
+ dst_data,
633
+ cur_qSplitSize,
634
+ headSize_v,
635
+ cur_ekvSplitSize,
636
+ cur_ekvSplitSize,
637
+ vStrideN,
638
+ headSize_v);
639
+ }
640
+ }
641
+ } else {
642
+ int64_t psize = n / kvSplitSize * ekvSplitSize;
643
+ at::native::cpublas::brgemm(
644
+ cur_qSplitSize,
645
+ headSize_v,
646
+ cur_ekvSplitSize,
647
+ cur_ekvSplitSize,
648
+ headSize_v,
649
+ headSize_v,
650
+ n_idx > 0,
651
+ qk_reduced_data,
652
+ value_reorder_ptr +
653
+ i_kv * num_head_k * kv_padding_size * headSize_v +
654
+ j_kv * kv_padding_size * headSize_v + psize * headSize_v,
655
+ dst_data,
656
+ need_pack);
657
+ }
658
+ }
659
+
660
+ // dst <- dst / sum[row]
661
+ // reorder MHA output with strides
662
+ for (int64_t row = 0; row < cur_qSplitSize; ++row) {
663
+ // Row sums for full masked out rows are 0, we set them to 1
664
+ // in order to avoid NaNs in the output and instead set fully
665
+ // masked out rows to 0
666
+ qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
667
+ qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
668
+ accum_t sum_reciprocal = 1 / qk_sum_data[row];
669
+ at::vec::map<scalar_t>(
670
+ [sum_reciprocal, is_skip_kv](Vec x) { return is_skip_kv ? Vec(0.0) : x * Vec(sum_reciprocal); },
671
+ out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
672
+ dst_data + row * headSize_v,
673
+ headSize_v);
674
+ }
675
+
676
+ // Move to the next query
677
+ at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
678
+ }
679
+
680
+ at::native::cpublas::brgemm_release(need_pack);
681
+
682
+ });
683
+ }
684
+ """
685
+
686
+
687
class CppFlexAttentionTemplate(CppTemplate):
    """C++ codegen template for FlexAttention on the CPU (inductor) backend.

    Renders the ``FLEX_ATTENTION_TEMPLATE`` Jinja source (defined earlier in
    this file) into a complete C++ kernel, including the softmax fusion
    helpers, brgemm pack functions and micro-GEMM definitions.  The optional
    ``score_mod`` / ``mask_mod`` subgraphs are lowered to C++ loop bodies via
    :meth:`modification` and spliced into the attention inner loop.
    """

    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        scale,
        score_mod,
        mask_mod,
        kv_block_size,
        q_block_size,
        has_other_buffer,
        no_full_kv_block,
        fake_buffers,
        len_score_other,
        len_mask_other,
        kernel_input_name_to_buffer,
        block_vars,
    ) -> None:
        # Only fp32/bf16/fp16 outputs are supported by this template.
        assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
        super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
        self.scale = scale
        self.score_mod = score_mod
        self.mask_mod = mask_mod
        # Register the score/mask subgraphs as graph buffers so they get
        # stable "bufN"-style names we can reference from the template.
        self.score_buf_name = (
            V.graph.register_buffer(self.score_mod) if self.score_mod else None
        )
        self.mask_buf_name = (
            V.graph.register_buffer(self.mask_mod) if self.mask_mod else None
        )

        def get_idx(buf_name):
            # Extract the numeric suffix of a registered buffer name
            # (e.g. "buf3" -> "3") for use as out_ptr{idx} in the template.
            match = re.search(r"\d+", buf_name)
            assert match, f"incorrect score buf name: {buf_name}"
            return match.group()

        self.score_buf_idx = (
            get_idx(self.score_buf_name) if self.score_buf_name else None
        )
        self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None
        self.kv_block_size = kv_block_size
        self.q_block_size = q_block_size
        self.has_other_buffer = has_other_buffer
        self.no_full_kv_block = no_full_kv_block
        # When full KV blocks are present, input_nodes additionally carries
        # full_kv_num_blocks/full_kv_indices, shifting the "other buffer"
        # inputs by two positions.
        self.other_buffer_input_offset = 2
        if self.no_full_kv_block:
            self.other_buffer_input_offset = 0
        self.fake_buffers = fake_buffers
        self.len_score_other = len_score_other
        self.len_mask_other = len_mask_other
        self.kernel_input_name_to_buffer = kernel_input_name_to_buffer
        self.block_vars = block_vars
        # Symbolic (dynamic-shape) values among the kernel inputs become
        # extra size variables of the generated kernel.
        self.extra_sizevars = list(
            OrderedSet(
                val
                for val in self.kernel_input_name_to_buffer.values()
                if isinstance(val, sympy.Symbol)
            )
        )
        # Inputs [0..4] are query/key/value/kv_num_blocks/kv_indices; the
        # score/mask "other" buffers follow after the optional full-KV pair.
        self.other_buf_start_idx = 5
        self.score_mod_other_buffers = (
            self.input_nodes[
                self.other_buf_start_idx
                + self.other_buffer_input_offset : self.other_buf_start_idx
                + self.other_buffer_input_offset
                + self.len_score_other
            ]
            if self.has_other_buffer
            else None
        )
        self.mask_mod_other_buffers = (
            self.input_nodes[
                self.other_buf_start_idx
                + self.other_buffer_input_offset
                + self.len_score_other :
            ]
            if self.has_other_buffer
            else None
        )
        # Maps generated pointer names ("in_ptrN") to (kernel arg name, arg);
        # populated lazily by generate_other_buffer().
        self.other_ptr_data = {}  # type: ignore[var-annotated]

    def update_kernel_args(self, kernel_args):
        """Add all non-symbolic kernel inputs to *kernel_args* and return it."""
        kernel_args.update(
            {
                key: value
                for key, value in self.kernel_input_name_to_buffer.items()
                if not isinstance(value, sympy.Symbol)
            }
        )
        return kernel_args

    def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args):
        """Emit ``auto in_ptrN = <arg>;`` declarations for extra mod buffers.

        ``buf_list`` is the key prefix ("score_others"/"mask_others"),
        ``start_offset`` the pointer-index offset (-1 means "after the score
        buffers"), and ``len_attr`` the name of the length attribute on
        ``self``.  Also records the mapping in ``self.other_ptr_data`` so
        :meth:`modification` can wire the same pointers into the subgraph.
        Returns the joined C++ declarations (empty string when there are no
        other buffers).
        """
        kernel_input_name_to_buffer_name = {
            key: value if isinstance(value, sympy.Symbol) else value.get_name()
            for key, value in self.kernel_input_name_to_buffer.items()
        }

        def get_arg(name):
            return kernel_input_name_to_buffer_name.get(name)

        def get_arg_name(name):
            # Symbols live in sizevars; real buffers in input_buffers.
            if isinstance(get_arg(name), sympy.Symbol):
                return kernel_args.sizevars.get(get_arg(name))
            return kernel_args.input_buffers.get(get_arg(name))

        if not self.has_other_buffer:
            return ""

        if start_offset == -1:
            start_offset = getattr(self, len_attr)

        length = getattr(self, len_attr)
        for i in range(length):
            pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}"
            buffer_key = f"{buf_list}_{i}"
            if pointer not in self.other_ptr_data:
                self.other_ptr_data[pointer] = (
                    get_arg_name(buffer_key),
                    get_arg(buffer_key),
                )

        return "\n".join(
            f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items()
        )

    def modification(self, subgraph_buffer, output_name, output_idx):
        """Lower a score/mask subgraph to C++ loop code.

        Builds a ``LoopBody`` that stores the subgraph's result into
        ``out_ptr{output_idx}`` and codegens it with ``CppKernelProxy``,
        wiring the standard inputs (score, b, h, q_idx, kv_idx) plus any
        previously recorded "other" buffer pointers.  Returns the generated
        C++ source with symbolic split-size names rewritten to
        ``cur_qSplitSize`` / ``cur_kvSplitSize``.
        """
        assert isinstance(subgraph_buffer, ir.ComputedBuffer)
        subgraph_buffer_data = subgraph_buffer.data
        from ..loop_body import LoopBody
        from ..utils import sympy_index_symbol_with_prefix, SymT
        from ..virtualized import V
        from .cpp import CppKernelProxy, KernelGroup

        kernel_group = KernelGroup()
        # Fixed argument layout expected by FLEX_ATTENTION_TEMPLATE.
        kernel_input_args = {
            "score": "in_ptr0",
            "b": "in_ptr1",
            "h": "in_ptr2",
            "q_idx": "in_ptr3",
            "kv_idx": "in_ptr4",
        }
        if self.has_other_buffer:
            kernel_input_args.update(
                {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()}
            )

        kernel_output_args = {output_name: f"out_ptr{output_idx}"}

        args = kernel_group.args
        for name, inp in kernel_input_args.items():
            args.input_buffers[name] = inp

        for name, inp in kernel_output_args.items():
            args.output_buffers[name] = inp

        for name in self.extra_sizevars:
            args.sizevars[name] = f"k{name}"

        kernel_group.args = args

        cpp_kernel_proxy = CppKernelProxy(kernel_group)
        bodies = []
        var_sizes_list = []
        var_sizes = tuple(subgraph_buffer.get_size())
        var_ranges = {
            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
            for i, sz in enumerate(var_sizes)
        }

        dst_layout = subgraph_buffer.get_layout()
        output_index = dst_layout.make_indexer()([*var_ranges.keys()])

        def fn(*args):
            # Store the subgraph result at the indexed output location.
            V.ops.store(
                output_name,
                output_index,
                subgraph_buffer_data.make_loader()(args).value,
            )

        body = LoopBody(
            fn,
            (list(var_ranges.keys())),
            var_ranges,
            list(var_ranges.keys()),
            tuple(),
        )

        from ..loop_body import MemoryUsageType

        assert all(
            mem.buffer_name in kernel_group.args.input_buffers
            for mem in body.memory_usage[MemoryUsageType.LOAD]
        ), (
            "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
        )

        bodies.append(body)
        var_sizes_list.append((var_sizes, ()))

        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
        kernel_group.finalize_kernel(cpp_kernel_proxy, [])
        output_code = kernel_group.loops_code.getvalue()

        var_q_symbol, var_kv_symbol = self.block_vars
        # See [Note] Handle the case where the split sizes are not statically known.
        # We don't know the value of qBlockSize and rkvBlockSize during compilation time
        # thus we've represented them by symbols.
        # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize"
        # in the generated code thus they'll be filled with the real value during runtime.
        if var_q_symbol in kernel_group.args.sizevars:
            output_code = output_code.replace(
                kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize"
            )
        if var_kv_symbol in kernel_group.args.sizevars:
            output_code = output_code.replace(
                kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize"
            )

        return output_code

    @staticmethod
    def add_choices(
        choices,
        input_nodes,
        layout,
        scale,
        score_mod,
        mask_mod,
        kv_block_size,
        q_block_size,
        has_other_buffer,
        no_full_kv_block,
        fake_buffers,
        len_score_other,
        len_mask_other,
        kernel_input_name_to_buffer,
        block_vars,
    ):
        """Append this template (wrapped for autotuning) to *choices*.

        The pre/post processors are identity functions; the wrapper simply
        constructs a ``CppFlexAttentionTemplate`` with the given arguments.
        Returns the created wrapper.
        """

        def preprocessor(input_nodes, layout):
            return input_nodes, layout

        def postprocessor(output):
            return output

        template = DataProcessorTemplateWrapper(
            CppFlexAttentionTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            scale=scale,
            score_mod=score_mod,
            mask_mod=mask_mod,
            kv_block_size=kv_block_size,
            q_block_size=q_block_size,
            has_other_buffer=has_other_buffer,
            no_full_kv_block=no_full_kv_block,
            fake_buffers=fake_buffers,
            len_score_other=len_score_other,
            len_mask_other=len_mask_other,
            kernel_input_name_to_buffer=kernel_input_name_to_buffer,
            block_vars=block_vars,
        )
        template.maybe_append_choice(choices)
        return template

    def apply_score_mod(self, score, b, h, q_idx, kv_idx):
        """Evaluate the score_mod graph module eagerly and return a scalar."""
        return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()

    def render(  # type: ignore[override,return]
        self,
        kernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        epilogue_nodes: Optional[list[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """Render the full C++ flex-attention kernel source.

        Permutes Q/K/V to (B, seq, H, head_dim) views, assembles the Jinja
        options dict, and renders ``FLEX_ATTENTION_TEMPLATE`` while patching
        ``V.graph.get_dtype`` for the fake buffers.  Epilogue fusion is not
        supported and raises ``NotImplementedError``.
        """
        if epilogue_nodes is not None and epilogue_nodes != []:
            raise NotImplementedError(
                "Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate."
            )
        # Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
        # -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
        # Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
        # -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
        # Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
        # -> (Batch x KV_seq_len x Num_heads x Dim_per_head)

        query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
        key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
        value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
        # Accumulation always happens in fp32 regardless of input dtype.
        self.accumulate_dtype = torch.float
        self.input_dtype = query.layout.dtype

        num_threads = parallel_num_threads()
        buf_out = TensorBox.create(self.output_node)
        if template_buffer_node is not None:
            buf_out = template_buffer_node
        options = dict(
            query=query,
            key=key,
            value=value,
            kv_num_blocks=self.input_nodes[3],
            kv_indices=self.input_nodes[4],
            full_kv_num_blocks=self.input_nodes[5]
            if not self.no_full_kv_block
            else None,
            full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None,
            score_mod_other_buffers=self.score_mod_other_buffers,
            mask_mod_other_buffers=self.mask_mod_other_buffers,
            scale=self.scale,
            has_full_kv_block=not self.no_full_kv_block,
            accumulate_dtype=self.accumulate_dtype,
            query_dtype=self.input_dtype,
            kvBlockSize=self.kv_block_size,
            qBlockSize=self.q_block_size,
            template=self,
            output=buf_out,
            kernel=kernel,
            num_thread=num_threads,
            score_mod=self.score_mod,
            mask_mod=self.mask_mod,
            score_buf_name=self.score_buf_name,
            mask_buf_name=self.mask_buf_name,
            score_buf_idx=self.score_buf_idx,
            mask_buf_idx=self.mask_buf_idx,
        )
        with contextlib.ExitStack() as stack:
            # Fake buffers have no real dtype registration; patch get_dtype
            # so template rendering can query them.
            for buf in self.fake_buffers:
                stack.enter_context(
                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
                )
            return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)

    def codegen_softmax_fusion(self, kernel_name: str):
        """Render the fused softmax helper kernels for *kernel_name*."""
        # TODO: use inductor IR to rewrite those fusions
        return self._template_from_string(SOFTMAX_FUSIONS).render(
            dict(kernel_name=kernel_name)
        )

    def codegen_brgemm_pack_function(self, kernel_name: str):
        """Render the brgemm weight-packing helper functions."""
        # TODO: make them general for common bmm templates
        return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render(
            dict(kernel_name=kernel_name)
        )

    def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size):
        """Render a local buffer allocation of the given name/dtype/size."""
        return self._template_from_string(ALLOCATE_BUFFER).render(
            dict(
                buffer_name=buffer_name,
                buffer_dtype=buffer_dtype,
                buffer_size=buffer_size,
            )
        )

    def micro_gemm_define(self, kernel_name: str):
        """Generate C++ definitions for the two FP32-vec micro-GEMMs.

        Creates one transposed-B variant (used for q @ k.T) and one plain
        variant (used for p @ v), both with 1x16x1 register blocking, and
        returns their concatenated codegen output.
        """
        from torch._inductor.codegen.cpp_gemm_template import (
            CppTemplateKernel,
            parallel_num_threads,
        )
        from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
        from torch._inductor.virtualized import V

        micro_gemm_trans = CppMicroGemmFP32Vec(
            kernel_name + "_kernel_micro_gemm_transpose_b",
            self.input_dtype,
            self.input_dtype,
            self.accumulate_dtype,
            self.accumulate_dtype,
            GemmBlocking(1, 16, 1),
            1,
            True,
            True,
        )

        micro_gemm = CppMicroGemmFP32Vec(
            kernel_name + "_kernel_micro_gemm",
            self.input_dtype,
            self.input_dtype,
            self.accumulate_dtype,
            self.accumulate_dtype,
            GemmBlocking(1, 16, 1),
            1,
            True,
            False,
        )

        with V.set_graph_handler(V.graph):
            kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
            code_trans = micro_gemm_trans.codegen_define(kernel)
            code = micro_gemm.codegen_define(kernel)
        return code + code_trans

    def codegen_micro_gemm(self, kernel_name: str):
        """Render the micro-GEMM template with the generated definitions."""
        micro_gemm = self.micro_gemm_define(kernel_name)
        GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
        return self._template_from_string(GEMM_SOURCE_CODE).render()
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py ADDED
@@ -0,0 +1,1777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import logging
4
+ import math
5
+ from functools import lru_cache
6
+ from typing import Any, Callable, cast, Optional, TypeVar, Union
7
+ from unittest.mock import patch
8
+
9
+ import torch
10
+ import torch.utils
11
+ from torch.utils._ordered_set import OrderedSet
12
+
13
+ from ..._dynamo.utils import counters
14
+ from .. import config, ir, lowering as L
15
+ from ..kernel.mm_common import mm_args
16
+ from ..select_algorithm import DataProcessorTemplateWrapper
17
+ from ..utils import (
18
+ has_free_symbols,
19
+ is_same_mkldnn_tensor,
20
+ is_same_tensor,
21
+ parallel_num_threads,
22
+ )
23
+ from ..virtualized import ops, V
24
+ from .cpp import get_export_declaration
25
+ from .cpp_micro_gemm import (
26
+ CppMicroBrgemm,
27
+ CppMicroGemm,
28
+ CppMicroGemmAMX,
29
+ CppMicroGemmFP32Vec,
30
+ create_micro_gemm,
31
+ is_int8_woq_gemm_small_m_dim_corner_case,
32
+ LayoutType,
33
+ )
34
+ from .cpp_template import CppTemplate
35
+ from .cpp_template_kernel import CppTemplateKernel
36
+ from .cpp_utils import (
37
+ create_epilogue_with_attr,
38
+ DTYPE_TO_CPP,
39
+ GemmBlocking,
40
+ get_gemm_template_output_and_compute_dtype,
41
+ )
42
+
43
+
44
+ log = logging.getLogger(__name__)
45
+
46
# Jinja/C++ snippet: declares the compile-time GEMM dimensions (M, N, K),
# the micro-kernel register blocking (Mr, Nr, Kr) and the derived register
# block counts.  M is runtime (`const`) when shapes are dynamic, otherwise
# `constexpr`.
GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r"""
    constexpr int64_t num_threads = {{num_threads}};
    constexpr int64_t N = {{N}};
    constexpr int64_t K = {{K}};
    constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
    constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
    constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
{%- if is_dynamic_M %}
    const int64_t M = {{kernel.size(GemmOut, 0)}};
    const int64_t Mr_blocks = (M + Mr - 1) / Mr;
{%- else %}
    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
{%- endif %}
"""

# Jinja/C++ snippet: computes the per-thread (Mt/Nt/Kt) and cache-level
# (Mc/Nc/Kc) blocking.  For dynamic M this is done at runtime via
# mm_get_thread_blocking / mm_get_cache_blocking; for static shapes the
# values are baked in from the template's Python-side blocking, and a final
# assert checks that the thread partitioning covers all register blocks.
GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r"""
{%- if is_dynamic_M %}
{%- if num_threads > 1 %}
    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
    mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
{%- else %}
    const auto Mt_blocks = Mr_blocks;
    const auto Nt_blocks = Nr_blocks;
    const auto Kt_blocks = Kr_blocks;
{%- endif %}
    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
    uint32_t L1_cache_size = {{L1_cache_size}};
    uint32_t L2_cache_size = {{L2_cache_size}};
    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
        num_threads,
        M,
        N,
        K,
        Mr,
        Nr,
        Kr,
        Mt_blocks,
        Nt_blocks,
        Kt_blocks,
        Mc_blocks,
        Nc_blocks,
        Kc_blocks,
        L1_cache_size,
        L2_cache_size
    );
    const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
    constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}};
    constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}};
    constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}};
    constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}};
    constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}};
    constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}};
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}
{%- if is_woq_int4 %}
    int64_t group_size = *q_group_size;
{%- endif %}

    // make sure all partitions are assigned
    {{kernel.assert_function}}(
        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
"""

# Jinja/C++ snippet: per-thread loop bounds inside an OpenMP parallel
# region.  Derives this thread's K-slice, N-slice and M-range from its
# omp thread id.
GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r"""
    const int tid = omp_get_thread_num();
    const int64_t k_group_id = tid / num_Kt_blocks;
    const int64_t k_slice_id = tid % num_Kt_blocks;
    const int64_t n_group_id = k_group_id / num_Nt_blocks;
    const int64_t n_slice_id = k_group_id % num_Nt_blocks;
    const int64_t k_block_start = k_slice_id * Kt_blocks;
    const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
    const int64_t n_block_start = n_slice_id * Nt_blocks;
    const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
    const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
    const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
    const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
"""

# Jinja/C++ snippet: the single-threaded equivalent of the parameters
# above — one thread owns the full M/N/K range.
GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r"""
    constexpr int tid = 0;
    constexpr int64_t k_group_id = 0;
    constexpr int64_t k_slice_id = 0;
    constexpr int64_t n_group_id = 0;
    constexpr int64_t n_slice_id = 0;
    constexpr int64_t m_block_start = 0;
    constexpr int64_t n_block_start = 0;
    constexpr int64_t n_block_end = Nr_blocks;
    constexpr int64_t k_block_start = 0;
    constexpr int64_t k_block_end = Kr_blocks;
{%- if is_dynamic_M %}
    const int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
    const int64_t m_block_end = Mr_blocks;
{%- else %}
    constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
    constexpr int64_t m_block_end = Mr_blocks;
{%- endif %}
"""

# Jinja/C++ snippet: per-iteration bounds of the M (row) cache-block loop;
# the (mc_block_id + n_slice_id) rotation staggers threads across M blocks.
GEMM_TEMPLATE_M_LOOP_PARAMS = r"""
    const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
    const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
    const int64_t m_start = mc * Mr;
    const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
    const int64_t m_size = m_end - m_start;
"""

# Jinja/C++ snippet: per-iteration bounds of the N (column) cache-block loop.
GEMM_TEMPLATE_N_LOOP_PARAMS = r"""
    const int64_t n_start = nc * Nr;
    const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
    const int64_t n_size = n_end - n_start;
    // NB: assume we pad N, nc_block_end won't exceed padded N here.
    const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
"""

# Jinja snippet: emits the common header plus the micro-GEMM function
# definitions for this kernel.
GEMM_TEMPLATE_MICROKERNEL_DEF = r"""
{{template.header().getvalue()}}

{{micro_gemm.codegen_define(kernel)}}
"""

# Jinja snippet: the extern "C" kernel signature.  The input argument set
# depends on the variant: quantized (x_scale/x_zp/w_scale/w_zp), WOQ-int4
# (q_group_size/qscale_and_zeros), or plain (X/W/inp).
GEMM_TEMPLATE_STUB_DEF = r"""
{%- if x_scale is not none %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
{%- elif is_woq_int4 %}
{%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %}
{%- else %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
{%- endif %}

extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
"""
192
+
193
# Main Jinja template for the generated C++ GEMM kernel:
# - emits the kernel stub via template.codegen_gemm_stub_def()
# - emits blocking variables via template.codegen_blocks(...)
# - iterates Mc/Nc/Kc cache blocks, calling micro_gemm.codegen_call per
#   register tile (accum=False on the first K block, accum=True afterwards)
# - with k-slicing enabled, per-thread partial accumulators are stashed in
#   local_buf_ptrs and reduced across the K-thread group after an omp barrier.
# NOTE(review): the body references names (e.g. mc, nc_block_end, tile_acc,
# tile_Y in the k-slicing epilogue) that are bound by the codegen_* helper
# snippets spliced in above — confirm against the renderer before editing.
GEMM_TEMPLATE = r"""
{{ template.codegen_gemm_stub_def() }}
{
{{ kernel.maybe_codegen_profile() }}
{{ template.codegen_blocks(
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
) }}

{%- if maybe_k_slicing %}
std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
if (num_Kt_blocks > 1) {
local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]);
}
{%- endif %}

{%- if num_threads > 1 %}
#pragma omp parallel num_threads({{num_threads}})
{
{{ template.codegen_multi_threads_params()|indent(8, false) }}
{%- else %}
{
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
{%- endif %}
{{ micro_gemm.codegen_init(kernel) }}
{%- if use_local_acc %}
{%- set acc_buf_name = "local_acc_buf" %}
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
{%- endif %}
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
{{ template.codegen_m_loop_params()|indent(12, false) }}
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
{{ template.codegen_n_loop_params()|indent(16, false) }}
{%- if use_local_acc %}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{{ kernel.reinit_buffer_if_null(acc_buf_name) }}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- endif %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
for (int64_t nci = nc; nci < nc_block_end; nci++) {
{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
{%- if template.should_block_weights and not is_woq_int4 %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
{%- else %}
{%- if is_woq_int4 %}
{%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}
{%- set tile_qparam = kernel.slice_nd(
qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %}
{%- else %}
{%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
{%- set tile_qparam = None %}
{%- endif %}
{%- endif %}
if (kc == k_block_start) {
{{ micro_gemm.codegen_call(kernel,
tile_X,
tile_W,
acc_slice,
accum=False,
qscale_and_zeros=tile_qparam)|indent(28, false)
}}
} else {
{{ micro_gemm.codegen_call(kernel,
tile_X,
tile_W,
acc_slice,
accum=True,
qscale_and_zeros=tile_qparam)|indent(28, false)
}}
}
}
}
{%- if maybe_k_slicing %}
if (num_Kt_blocks > 1) {
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset(
{{ kernel.release_buffer(acc_buf_name) }});
} else
{%- endif %}
{
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
{{ kernel.store_output(
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
)|indent(20, false)
}}
}
}
}
{%- if maybe_k_slicing %}
if (num_Kt_blocks > 1) {
#pragma omp barrier
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
// We slice M-dim and each thread in the k-slicing group works on a slice
const int64_t m_start_unsliced = mc * Mr;
const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks;
const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
const int64_t m_size = m_end - m_start;
const int64_t m_offset = m_start - m_start_unsliced;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get();
for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) {
auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get();
for (int64_t m = m_offset; m < m_offset + m_size; m++) {
#pragma omp simd
for (int64_t n = 0; n < n_size; n++) {
{{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
}
}
}
{%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
{{ kernel.store_output(
tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
)|indent(20, false)
}}
}
}
}
{%- endif %}
{{ micro_gemm.codegen_finalize(kernel) }}
}
}
"""
327
+
328
# Jinja template for GEMMs with small M: parallelizes only over Nr blocks
# ("#pragma omp for nowait", one M x Nr output panel per thread) and keeps the
# K loop sequential inside each thread. The first and last K iterations get
# special accum/prefetch flags in micro_gemm.codegen_call.
SMALL_M_GEMM_TEMPLATE = r"""
{{ template.codegen_gemm_stub_def() }}
{
{{ kernel.maybe_codegen_profile() }}
{{ template.codegen_blocks(
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
) }}
# pragma omp parallel
{
#pragma omp for nowait
for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) {
// Handle one output M * Nr block in each thread
int64_t n_start = nr_block_id * Nr;
int64_t n_end = (nr_block_id + 1) * Nr;
{%- if use_local_acc %}
{%- set acc_buf_name = "local_acc_buf" %}
{{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %}
{%- endif %}
for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) {
// this loop is not parallelized
int64_t k_start = kr_block_id * Kr;
int64_t k_end = std::min((kr_block_id + 1) * Kr, K);
{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
if C10_UNLIKELY(kr_block_id == 0) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }}
} else if C10_UNLIKELY(k_end == K) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }}
} else {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }}
}
}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %}
{{ kernel.store_output(
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers
)|indent(20, false) }}
}
}
}
"""
373
+
374
+
375
+ def _is_int8_gemm(inputs):
376
+ return (
377
+ isinstance(inputs[0], ir.IRNode)
378
+ and inputs[0].get_dtype() in [torch.uint8, torch.int8]
379
+ ) or (
380
+ isinstance(inputs[0], torch.Tensor)
381
+ and inputs[0].dtype in [torch.uint8, torch.int8]
382
+ )
383
+
384
+
385
def get_padded_n(n, block_n):
    """Round ``n`` up to the next multiple of ``block_n``."""
    # -(-n // block_n) is ceil(n / block_n) in integer arithmetic.
    full_blocks = -(-n // block_n)
    return full_blocks * block_n
387
+
388
+
389
_T = TypeVar("_T", ir.IRNode, torch.Tensor)


def transpose_w(W: _T, trans_w: bool) -> _T:
    """
    Transpose W based on the trans_w flag.

    No-op when ``trans_w`` is False; otherwise swaps the two dims of an
    IR node (via the permute lowering) or a plain tensor.
    """
    if not trans_w:
        return W
    if isinstance(W, ir.IRNode):
        # Lowerings operate on TensorBox, so wrap the node first if needed.
        boxed = W if isinstance(W, ir.TensorBox) else ir.TensorBox(W)
        return L.permute(boxed, [1, 0])
    assert isinstance(W, torch.Tensor)
    return W.transpose(0, 1)
406
+
407
+
408
def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
    """
    Expand Bias to the same size of X.

    Broadcasts B to shape (X.rows, B.cols); returns None unchanged.
    """
    if B is None:
        return None
    if isinstance(B, ir.IRNode):
        boxed = B if isinstance(B, ir.TensorBox) else ir.TensorBox(B)
        assert hasattr(X, "get_size")
        return L.expand(boxed, (X.get_size()[0], boxed.get_size()[-1]))
    assert isinstance(B, torch.Tensor)
    assert isinstance(X, torch.Tensor)
    return B.expand(X.shape[0], B.shape[-1])
423
+
424
+
425
def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]):
    """
    Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.

    Scans `input_nodes` for constant buffers (weight/bias) that were replaced by
    the packed versions in `new_input_nodes` and, when a constant has exactly one
    remaining `get_attr` user, deletes it from both `V.graph.constants` and the
    FX module. Mutates `V.graph` in place; returns nothing.
    """

    def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor):
        # Two tensors count as "the same" if they alias the same storage,
        # for either dense or mkldnn tensors (never across the two kinds).
        return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and (
            is_same_tensor(base_tensor, comp_tensor)
            or is_same_mkldnn_tensor(base_tensor, comp_tensor)
        )

    def get_candidates(input_nodes, new_input_nodes):
        # Only Constant Buffer like weight and bias might be changed in GEMM Template.
        # The Inductor IR Node may changed, but still share the storage. For example:
        # bias in bfloat16 case which only do the expand
        return [
            node
            for node in input_nodes
            if (
                node not in new_input_nodes
                and isinstance(node, (ir.TensorBox, ir.StorageBox))
                and node.get_name() in V.graph.constants
                and not any(
                    (
                        isinstance(new_node, (ir.TensorBox, ir.StorageBox))
                        and new_node.get_name() in V.graph.constants
                        and share_storage(
                            V.graph.constants[node.get_name()],
                            V.graph.constants[new_node.get_name()],
                        )
                    )
                    for new_node in new_input_nodes
                )
            )
        ]

    for candidate_node in get_candidates(input_nodes, new_input_nodes):
        # By using the new packed weight for the GEMM template, we can prune the
        # old weight if it has no other users. This saves memory but makes the FX graph
        # non-retraceable. To support retracing, we can add a repack node to the
        # FX graph. For example:
        # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
        # First pass: count how many get_attr nodes alias this constant's storage.
        candidate_tensor_users = 0
        candidate_tensor = V.graph.constants[candidate_node.get_name()]
        for node in reversed(V.graph.graph.nodes):
            # Case may happen when the candidate tensor is used by more than 1 get_attr node
            # https://github.com/pytorch/pytorch/issues/134998
            if node.op == "get_attr" and hasattr(
                V.graph.module, node.target
            ):  # candidate tensor might already be deleted
                comp_tensor = getattr(V.graph.module, node.target)
                if isinstance(comp_tensor, torch.Tensor) and share_storage(
                    candidate_tensor, comp_tensor
                ):
                    candidate_tensor_users += 1

        # Second pass: delete only when it is safe (single get_attr user of the
        # storage AND that get_attr has a single FX user).
        for node in reversed(V.graph.graph.nodes):
            # The get_attr node has only 1 user fx node
            # The candidate tensor has been used by only 1 get_attr node
            if (
                node.op == "get_attr"
                and node.target == candidate_node.get_name()
                and len(node.users) == 1
                and candidate_tensor_users == 1
            ):
                del V.graph.constants[node.target]
                delattr(V.graph.module, node.target)
                delattr(V.graph.graph.owning_module, node.target)
                counters["inductor"]["select_algorithm_weight_prune"] += 1
494
+
495
+
496
def gen_2d_view_of_epilogue_buf(
    Y: ir.Buffer,
    template_buffer: ir.Buffer,
    epilogue_nodes: list[ir.IRNode],
    reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
    default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
) -> tuple[
    Union[ir.Buffer, ir.ReinterpretView],
    list[Optional[Callable[[list[Any]], list[Any]]]],
]:
    """
    The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is
    2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code,
    we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to
    the indexing of the GEMM output.
    In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and
    build a reindexer that converts the indexing of `Y` into `Y_2d`.

    `reindexers` is extended in place (one entry per epilogue node) and also
    returned alongside `Y_2d`.
    """
    Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
    if (
        Y.get_size() == template_buffer.get_size()
        and Y.get_stride() == template_buffer.get_stride()
    ):
        # Same layout as the GEMM output: use Y directly with the default reindexers.
        reindexers.extend(default_reindexers)
        Y_2d = Y
    else:

        def get_reindexer(epilogue_node, default_reindexer=None):
            # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
            #   template_buffer:
            #       size (324, 512), stride (512, 1)
            #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
            #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
            stride_order = list(
                ir.get_stride_order(
                    V.graph.sizevars.size_hints(epilogue_node.get_stride())
                )
            )
            fill_order = ir.stride_order2fill_order(stride_order)
            reversed_fill_order = list(reversed(fill_order))
            size_with_stride_ordered_decreasingly = [
                epilogue_node.get_size()[i] for i in reversed_fill_order
            ]
            reshape_reindex = ir.View.dynamic_reshape_indexer(
                size_with_stride_ordered_decreasingly,
                template_buffer.get_size(),
            )
            if default_reindexer:
                reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer)

            # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
            #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
            #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
            #   epilogue_node:
            #       size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
            from_stride_ordered_decreasingly_to_epilogue_node_order = [
                (len(stride_order) - 1) - stride_order[i]
                for i in range(len(stride_order))
            ]
            stride_reindex = ir.same_reorder(
                from_stride_ordered_decreasingly_to_epilogue_node_order
            )

            reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)  # type: ignore[var-annotated]
            return reindexer

        # NOTE(review): the None check happens only on this branch; the `if`
        # branch above would raise if default_reindexers were None — presumably
        # callers pass a real list whenever layouts match. TODO confirm.
        if default_reindexers is None:
            default_reindexers = [None] * len(epilogue_nodes)
        new_reindexers = [
            get_reindexer(epilogue_node, default_reindexer)
            for epilogue_node, default_reindexer in zip(
                epilogue_nodes, default_reindexers
            )
        ]
        reindexers.extend(new_reindexers)
        if isinstance(Y, ir.BaseView):
            storage = ir.StorageBox(Y.unwrap_view())
        else:
            assert isinstance(Y, ir.Buffer)
            storage = ir.StorageBox(Y)
        # Reinterpret Y's storage with the 2D GEMM-output layout.
        Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout())
    return Y_2d, reindexers
578
+
579
+
580
+ class CppGemmTemplate(CppTemplate):
581
+ """
582
+ GEMM Template for Inductor CPP Backend.
583
+ """
584
+
585
    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
        has_bias=False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = True,
        name="packed_gemm",
    ) -> None:
        """
        Initialize the GEMM template.

        ``input_nodes`` are expected in [X, W, inp?, ...] order (see
        ``add_choices``); M and N come from the output layout, K from X's
        last dimension.
        """
        assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
        super().__init__(
            name,
            input_nodes,
            layout,
            num_threads,
            epilogue_creator=epilogue_creator,
        )
        self.beta = beta
        self.alpha = alpha
        self.has_bias = has_bias
        self.register_blocking = register_blocking
        # M, N from the output layout; K from the activation's last dim.
        m, n = layout.size[-2:]
        k = input_nodes[0].get_size()[-1]
        self.m, self.n, self.k = m, n, k
        # N padded up to a multiple of the register block for weight packing.
        self.padded_n = get_padded_n(n, self.register_blocking.block_n)
        # Dynamic M means thread/cache blocking must be deferred to runtime.
        self.is_dynamic_M = has_free_symbols((m,))
        self.should_block_weights = should_block_weights
        # Memoized blocking computations, keyed by thread count.
        self.thread_blocking = self.make_thread_blocking_cache()
        self.cache_blocking = self.make_cache_blocking_cache()
618
+
619
+ def make_thread_blocking_cache(self):
620
+ cache = lru_cache()(self._thread_blocking)
621
+
622
+ def thread_blocking(num_threads: int) -> GemmBlocking:
623
+ return cache(num_threads)
624
+
625
+ return thread_blocking
626
+
627
    def _thread_blocking(self, num_threads: int) -> GemmBlocking:
        """
        NOTE [Thread blocking in Cpp GEMM]
        We use simple heuristics to decide the thread blocking:
        1. Make sure all threads are occupied as much as possible.
        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
        TODO(jgong5): allow tuning various blocking options

        Returns per-thread block counts (in units of register blocks) as a
        ``GemmBlocking(m, n, k)``. Requires static M (asserted below).
        """

        def get_factors(number):
            # All factor pairs of `number`, interleaved as (number//i, i) with i
            # descending from sqrt(number) — so larger cofactors come first.
            factors = []
            for i in range(int(number**0.5), 0, -1):
                if number % i == 0:
                    factors.append(number // i)
                    factors.append(i)
            return factors

        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
            # Split each dimension's register-block count evenly across its factor.
            thread_block_k = math.ceil(k_blocks / k_factor)
            thread_block_n = math.ceil(n_blocks / n_factor)
            thread_block_m = math.ceil(m_blocks / m_factor)
            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

        assert not self.is_dynamic_M, (
            "Unable to determine thread blocking for dynamic M."
        )
        register_blocking = self.register_blocking
        m_blocks = math.ceil(self.m / register_blocking.block_m)
        n_blocks = math.ceil(self.n / register_blocking.block_n)
        k_blocks = math.ceil(self.k / register_blocking.block_k)
        factors = get_factors(num_threads)
        assert len(factors) > 0

        # User override: "m,n,k" thread factors from config bypass the search.
        if config.cpp.gemm_thread_factors is not None:
            factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
            assert len(factors) == 3
            assert math.prod(factors) == self.num_threads
            return get_blocking(
                factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
            )

        # we favor square-sized thread blocks for good data reuse
        def get_better_blocking(blocking, best_blocking):
            # Prefer larger block_k first (less k-slicing), then the smaller
            # m+n perimeter (more square-shaped block).
            if best_blocking is None:
                best_blocking = blocking
            else:
                block_m_size = blocking.block_m * register_blocking.block_m
                block_n_size = blocking.block_n * register_blocking.block_n
                best_block_m_size = best_blocking.block_m * register_blocking.block_m
                best_block_n_size = best_blocking.block_n * register_blocking.block_n
                if blocking.block_k > best_blocking.block_k:
                    best_blocking = blocking
                elif (
                    blocking.block_k == best_blocking.block_k
                    and block_m_size + block_n_size
                    < best_block_m_size + best_block_n_size
                ):
                    best_blocking = blocking
            return best_blocking

        best_blocking = None
        # check if we can have a thread-blocking to occupy all threads without k-slicing
        for n_factor in factors:
            m_factor = num_threads // n_factor
            if n_blocks >= n_factor and m_blocks >= m_factor:
                blocking = get_blocking(
                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                )
                best_blocking = get_better_blocking(blocking, best_blocking)

        # Fallback 1: allow k-slicing (bounded by gemm_max_k_slices if set).
        if best_blocking is None:
            for k_factor in factors:
                if k_blocks >= k_factor and (
                    config.cpp.gemm_max_k_slices == 0
                    or k_factor <= config.cpp.gemm_max_k_slices
                ):
                    n_factors = get_factors(num_threads // k_factor)
                    for n_factor in n_factors:
                        m_factor = (num_threads // k_factor) // n_factor
                        if n_blocks >= n_factor and m_blocks >= m_factor:
                            blocking = get_blocking(
                                m_factor,
                                n_factor,
                                k_factor,
                                m_blocks,
                                n_blocks,
                                k_blocks,
                            )
                            best_blocking = get_better_blocking(blocking, best_blocking)

        # Fallback 2: accept partial occupancy (only one of m/n covers its factor).
        if best_blocking is None:
            for n_factor in factors:
                m_factor = num_threads // n_factor
                if n_blocks >= n_factor or m_blocks >= m_factor:
                    blocking = get_blocking(
                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                    )
                    best_blocking = get_better_blocking(blocking, best_blocking)

        assert best_blocking is not None
        return best_blocking
729
+
730
+ def make_cache_blocking_cache(self):
731
+ cache = lru_cache()(self._cache_blocking)
732
+
733
+ def cache_blocking(num_threads: int) -> GemmBlocking:
734
+ return cache(num_threads)
735
+
736
+ return cache_blocking
737
+
738
    def _cache_blocking(self, num_threads: int) -> GemmBlocking:
        """
        Compute cache blocking (Mc, Nc, Kc) in units of register blocks,
        fitting B panels into L1 and A panels into L2 (see the NOTE inside).
        Requires static M (asserted below).
        """

        def get_cache_blocking(register_blocking, thread_blocking):
            Mr = register_blocking.block_m
            Nr = register_blocking.block_n
            Kr = register_blocking.block_k

            Mt_blocks = thread_blocking.block_m
            Nt_blocks = thread_blocking.block_n
            Kt_blocks = thread_blocking.block_k

            # User override from config, clamped to the thread blocking.
            if config.cpp.gemm_cache_blocking is not None:
                blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
                assert len(blockings) == 3
                Mc_blocks, Nc_blocks, Kc_blocks = blockings
                return (
                    min(Mc_blocks, Mt_blocks),
                    min(Nc_blocks, Nt_blocks),
                    min(Kc_blocks, Kt_blocks),
                )

            # The ratios below are empirically determined to decide
            # the effective sizes of L1 and L2.
            # TODO: tune the factor here
            L1_limit_factor = 0.8
            L2_limit_factor = 0.5

            L1_cache_size = (
                torch._C._cpu._L1d_cache_size()
            )  # per core cache size in Bytes
            assert L1_cache_size > 0, (
                f"Expect L1_cache_size > 0 but got {L1_cache_size}"
            )
            L1 = L1_cache_size * L1_limit_factor

            L2_cache_size = (
                torch._C._cpu._L2_cache_size()
            )  # per core cache size in Bytes
            assert L2_cache_size > 0, (
                f"Expect L2_cache_size > 0 but got {L2_cache_size}"
            )
            L2 = L2_cache_size * L2_limit_factor

            def get_num_byte(dtype):
                # Element size in bytes for the given dtype.
                return torch.tensor([], dtype=dtype).element_size()

            dtype_A = self.input_nodes[0].get_dtype()
            dtype_B = self.input_nodes[1].get_dtype()
            num_byte_A = get_num_byte(dtype_A)
            num_byte_B = get_num_byte(dtype_B)
            if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1:
                # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel.
                # In this case, the choice of the micro-kernel being used can't be decoupled from
                # the cache blocking.
                # TODO: Decouple the choice of micro-kernel from cache blocking
                num_byte_B *= num_byte_A

            # NOTE [CPP GEMM Cache Blocking Algorithm]
            # Our overall strategy is to
            # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
            #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
            #    decide Kc. We want to make Mc large enough to better reuse B.
            # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
            #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

            # Step 1: Decide Kc assuming B block is L1-reside.
            size_cache_B = Kr * Kt_blocks * Nr * num_byte_B

            Kc_blocks = Kt_blocks
            if size_cache_B > L1:
                Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))

            if (
                config.cpp.use_small_dequant_buffer
                and dtype_A is torch.bfloat16
                and dtype_B is torch.uint8
                and Mt_blocks == 1
            ):
                # Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
                # Since when Mt_blocks == 1, L1-reside B block can't be reused by A.
                if Kc_blocks * Kr >= self.q_group_size():
                    Kc_blocks = self.q_group_size() // Kr

            # Step 2: Decide Mc assuming A block is L2-reside.
            min_Mc_ratio = 2  # TODO(jgong5): something to tune?
            min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
            assert min_Mc_blocks >= 1
            Kt_bytes = Kt_blocks * Kr * num_byte_A
            if min_Mc_blocks * Mr * Kt_bytes < L2:
                # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
                # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
                # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
                # in L1.
                Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
                Nc_blocks = 1
            else:
                # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
                # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
                Mc_blocks = Mt_blocks
                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
                Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
                Kc_bytes = Kc_blocks * Kr * num_byte_A
                if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
                    # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
                    # assuming Mc == Nc for good data reuse.
                    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
                    if M_max < Mc_blocks * Mr:
                        Mc_blocks = math.floor(M_max / Mr)
                        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)

            return Mc_blocks, Nc_blocks, Kc_blocks

        assert not self.is_dynamic_M, (
            "Unable to determine cache blocking for dynamic M."
        )
        register_blocking = self.register_blocking
        thread_blocking = self.thread_blocking(num_threads)

        return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
856
+
857
    def log_blockings(self):
        """Log register/cache/thread blockings and thread occupancy (debug only)."""
        log.debug(f"Register blocking: {self.register_blocking}")  # noqa: G004
        if self.is_dynamic_M:
            # thread and cache blockings are determined at runtime for dynamic shapes
            return
        log.debug(
            f"Cache blocking: {self.cache_blocking(self.num_threads)}"  # noqa: G004
        )
        thread_blocking = self.thread_blocking(self.num_threads)
        log.debug(f"Thread blocking: {thread_blocking}")  # noqa: G004

        def get_occupancy():
            # How many thread blocks each dimension is split into.
            m_blocks = math.ceil(self.m / self.register_blocking.block_m)
            n_blocks = math.ceil(self.n / self.register_blocking.block_n)
            k_blocks = math.ceil(self.k / self.register_blocking.block_k)
            m = math.ceil(m_blocks / thread_blocking.block_m)
            n = math.ceil(n_blocks / thread_blocking.block_n)
            k = math.ceil(k_blocks / thread_blocking.block_k)
            return (m, n, k)

        log.debug(
            f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
        )
880
+
881
+ def maybe_k_slicing(self):
882
+ if self.num_threads == 1:
883
+ return False
884
+ if self.is_dynamic_M:
885
+ # TODO(jgong5): perhaps use size hint to decide?
886
+ return True
887
+ register_blocking = self.register_blocking
888
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
889
+ thread_blocking = self.thread_blocking(self.num_threads)
890
+ return k_blocks > thread_blocking.block_k
891
+
892
    @classmethod
    def add_choices(
        cls,
        choices,
        layout,
        input_nodes,
        beta=1,
        alpha=1,
        has_bias=False,
        trans_w=False,
        input_indices=None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
    ):
        """
        Add choices for the GEMM template.

        Normalizes the inputs (reorders to [X, W, inp, ...], densifies/transposes
        W, expands bias), picks a micro-kernel, and registers a
        DataProcessorTemplateWrapper in `choices` whose pre/post-processors pack
        the weight and prune the old constant.
        NOTE(review): `act_mapping` is not used in this body — presumably
        consumed by a subclass override; confirm before removing.
        """
        # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant
        use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all(
            (
                isinstance(input_nodes[idx], ir.TensorBox)
                and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer)
            )
            for idx in [1, 2, 4]
        )

        if input_indices is None:
            input_indices = list(range(len(input_nodes)))
        # X@X.T-style cases pass the same node for both operands.
        only_one_input = (
            input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False
        )

        def reorder_and_filter(inputs, layout_or_out):
            # Canonicalize input order for the template: [X, W, inp?, extras...].
            if has_bias:
                assert len(input_indices) >= 3
                # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
                inp_idx = input_indices[0]
                x_idx = input_indices[1]
                w_idx = input_indices[2]
                return [
                    inputs[x_idx],
                    inputs[w_idx],
                    inputs[inp_idx],
                    *[inputs[idx] for idx in input_indices[3:]],
                ], layout_or_out
            elif len(inputs) >= len(input_indices):
                assert len(input_indices) >= 2
                return [inputs[idx] for idx in input_indices], layout_or_out
            else:
                # For when input is used for x and w, i.e. X@X.T or similar
                # Assumes the first input is the only input
                assert len(inputs) == 1
                return [inputs[0]] * len(input_indices), layout_or_out

        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
        is_mkldnn_wgt = (
            new_inputs[1].get_name() in V.graph.constants
            and V.graph.constants[new_inputs[1].get_name()].is_mkldnn
        )
        if is_mkldnn_wgt:
            # It shouldn't happen as viewing an mkldnn tensor, we can extend the
            # implementation if it does.
            assert not isinstance(new_inputs[1], ir.BaseView)
        # Note that the layout of MKLDNN Tensor is with the wrong stride
        view_size = new_inputs[1].layout.size
        view_stride = new_inputs[1].layout.stride
        view_offset = new_inputs[1].layout.offset

        def maybe_to_dense(inputs, layout_or_out):
            # Convert an mkldnn weight tensor to dense; leave IR nodes alone.
            new_inputs = list(inputs)
            if isinstance(inputs[1], torch.Tensor):
                W = inputs[1]
                new_inputs[1] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(inputs, layout_or_out):
            # Restore W's logical view, then apply transpose/bias expansion.
            new_inputs = list(inputs)
            if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor):
                if has_free_symbols(view_size):
                    # If batch size B is dynamic, we need to set the batch size and possibly stride
                    assert not has_free_symbols(view_size[1:])
                    view_size[:] = V.graph.sizevars.size_hints(view_size)
                    view_stride[:] = V.graph.sizevars.size_hints(view_stride)
                # With the assumptation that W is the storage of unwrap view
                # thus view it back here
                new_inputs[1] = new_inputs[1].as_strided(
                    view_size, view_stride, view_offset
                )

            if not trans_w:
                return new_inputs, layout_or_out
            X = new_inputs[0]
            W = new_inputs[1]
            B = new_inputs[2] if has_bias else None
            W = transpose_w(W, trans_w)
            B = expand_bias(B, X)  # type:ignore[arg-type]
            new_inputs[1] = W
            if B is not None:
                new_inputs[2] = B
            return new_inputs, layout_or_out

        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        m, n, k, *_ = mm_args(
            new_inputs[0],
            new_inputs[1],
            mat2_transposed=cls.is_woq_int4(),
            use_4x2_dim=cls.is_woq_int4(),
        )
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[1].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
            use_ref=not cls.is_woq_int4(),
            q_group_size=cls.q_group_size(),
        )
        assert micro_gemm is not None
        pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm)
        micro_gemm.use_local_vnni_blocking(not pre_block_weights)

        def preprocessor(inputs, layout):
            # Runs before codegen: canonicalize inputs and pack the weight.
            new_inputs, new_layout = normalize_shapes(
                *maybe_to_dense(*reorder_and_filter(inputs, layout))
            )
            if only_one_input and isinstance(new_inputs[0], torch.Tensor):
                return new_inputs[1:], new_layout
            return cls.prep_weight(
                new_inputs,
                new_layout,
                micro_gemm,
                pre_block_weights,
                use_int8_fast_compensation_path,
            )

        def postprocessor(output):
            # Runs after the template is chosen: swap the constant weight for
            # its packed version on the template buffer and prune the old one.
            if isinstance(output, ir.TensorBox):
                # prepack the weight as input to the template buffer
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)

                W_node = new_input_nodes[1]
                if W_node.get_name() not in V.graph.constants:
                    return output
                W = V.graph.constants[W_node.get_name()]
                new_input_nodes[1] = W
                new_input_nodes, new_layout = normalize_shapes(
                    *maybe_to_dense(new_input_nodes, layout)
                )
                new_input_nodes, _ = cls.prep_weight(
                    new_input_nodes,
                    new_layout,
                    micro_gemm,
                    pre_block_weights,
                    use_int8_fast_compensation_path,
                    skip_int8_compensation=True,
                )
                W_packed = new_input_nodes[1]
                W_packed_constant = V.graph.add_tensor_constant(W_packed)
                new_input_nodes[1] = W_packed_constant

                # Prune unused tensors
                prune_tensors(input_nodes, new_input_nodes)

                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                    W_packed_constant
                )
            return output

        template = DataProcessorTemplateWrapper(
            cls,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            should_block_weights=pre_block_weights,
            name=micro_gemm.__class__.__name__,
        )
        template.maybe_append_choice(choices)
        return template
1089
+
1090
@staticmethod
def get_padded_size(n, block_n, k, should_block_weight):
    """Compute the blocked 3D weight shape and the padded n dimension.

    For the plain GEMM template every weight tensor is blocked and padded,
    so ``should_block_weight`` is accepted but unused here; subclasses
    (e.g. the BMM template) override this and consult it.
    """
    rounded_n = get_padded_n(n, block_n)
    blocked_shape = [rounded_n // block_n, k, block_n]
    return blocked_shape, rounded_n
1096
+
1097
@classmethod
def prep_weight(
    cls,
    inputs,
    layout: ir.Layout,
    micro_gemm: CppMicroGemm,
    should_block_weight: bool,
    use_int8_fast_compensation_path: bool = False,
    skip_int8_compensation: bool = False,
):
    """
    NOTE Weight prep consists of 2 separate steps:
    1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n]
       This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM.
       For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway.
       This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel,
       and is worth the overhead of reshape and blocking.

       This blocking includes additional padding, when n is not a multiple of block_n.
       This padding allows a more efficient microkernel implementation. For BMM, this is only done
       if reshape would happen anyway, i.e. if the weight tensor is constant, is not contiguous,
       or is using AMX VNNI layout.
    2. Packing the weight tensor into a VNNI-friendly shape. For constant input,
       this is done at the same time as the weight blocking.

    At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM)
    which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime.

    CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate
    an additional dimension for the batch size and to determine if the weight tensor should be blocked.

    Returns a (new_inputs, layout) pair; for int8 GEMM the compensation
    tensors are appended to the returned input list.
    """
    W = inputs[1]
    new_inputs = list(inputs)
    # W may be an IR node (graph input/constant) or an eager torch.Tensor;
    # size queries differ accordingly.
    if cls.is_woq_int4():
        # WOQ INT4 weights are pre-packed 2D: [n, k-packed]
        assert (
            len(W.get_size()) == 2
            if isinstance(W, ir.IRNode)
            else len(W.shape) == 2
        )
        n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape
    else:
        k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:]
    _, block_n, _ = micro_gemm.register_blocking
    new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
    padding = padded_n - n

    if should_block_weight and not cls.is_woq_int4():
        # Step 1 (blocking) then step 2 (VNNI packing); see docstring NOTE.
        blocked_w = cls.block_weight(W, new_size, padding)
        new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)
    elif should_block_weight:
        # WOQ INT4: blocking only; reordering happens inside the microkernel.
        assert cls.is_woq_int4()
        new_inputs[1] = cls.block_weight(W, new_size, padding)
    elif isinstance(W, ir.IRNode):
        # Require W layout to be fixed & contiguous, happens inplace.
        ir.ExternKernel.require_contiguous(W)

    if not skip_int8_compensation and _is_int8_gemm(new_inputs):
        BCompensate = None
        x_w_scale = None

        def _get_compensation_node(W, use_int8_fast_compensation_path):
            # The "_BMatrixCompens" (and optional "_x_w_compens") constants are
            # presumably registered by the quantization lowering before this
            # runs — TODO(review): confirm against the caller.
            BCompensate = V.graph.add_tensor_constant(
                V.graph.constants[W.get_name() + "_BMatrixCompens"],
                W.get_name() + "_BMatrixCompens",
            )
            x_w_scale = None
            if use_int8_fast_compensation_path:
                x_w_scale = V.graph.add_tensor_constant(
                    V.graph.constants[W.get_name() + "_x_w_compens"],
                    W.get_name() + "_x_w_compens",
                )
            return BCompensate, x_w_scale

        if use_int8_fast_compensation_path:
            # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp]
            x_scale = new_inputs[-4]
            x_zp = new_inputs[-3]
            w_scale = new_inputs[-2]
            if isinstance(W, ir.IRNode):
                BCompensate, x_w_scale = _get_compensation_node(
                    W, use_int8_fast_compensation_path
                )
            else:
                # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
                BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
                assert all(
                    isinstance(item, torch.Tensor)
                    for item in (x_scale, x_zp, w_scale)
                )
                BCompensate = BCompensate * x_scale * w_scale * x_zp
                x_w_scale = x_scale * w_scale
            new_inputs.append(BCompensate)
            new_inputs.append(x_w_scale)
        else:
            if isinstance(W, ir.IRNode):
                BCompensate, _ = _get_compensation_node(
                    W, use_int8_fast_compensation_path
                )
            else:
                # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
                BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
            new_inputs.append(BCompensate)
    return new_inputs, layout
1200
+
1201
+ @staticmethod
1202
+ def check_if_block_weight(W, micro_gemm):
1203
+ return True
1204
+
1205
@classmethod
def block_weight(cls, W, new_size, padding):
    """Reshape/pad the weight into the blocked 3D layout given by ``new_size``.

    Handles three cases: a graph constant (only a relabeled buffer is
    created, the actual data blocking happens elsewhere), a non-constant
    IR node (lowered pad+view+permute ops), and an eager torch.Tensor
    (direct pad/reshape/transpose).
    """
    # These are separated into two methods to allow subclasses to override them separately
    if isinstance(W, ir.IRNode):
        if W.get_name() in V.graph.constants:
            # Create a new buffer, representing the constant blocked tensor
            blocked_w = ir.Buffer(
                name=W.get_name(),  # Borrow the registered buffer name
                layout=ir.FixedLayout(
                    W.get_device_or_error(),
                    W.get_dtype(),
                    new_size,
                    ir.FlexibleLayout.contiguous_strides(new_size),
                    0,
                ),
            )
        else:
            if not isinstance(W, ir.TensorBox):
                W = ir.TensorBox(W)
            # Pad the last dim, view as [..., k, n//block_n, block_n], then
            # swap the k and n//block_n dims to obtain [..., n//block_n, k, block_n].
            permute_dims = list(range(len(new_size)))
            permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2]
            permute_size = list(new_size)
            permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2]
            blocked_w = L.constant_pad_nd(W, (0, padding))
            blocked_w = L.permute(
                L.view(blocked_w, permute_size),
                permute_dims,
            )
    else:
        assert isinstance(W, torch.Tensor)
        # Pad the weight tensor and reshape it into a 3D blocked shape
        blocked_size = list(new_size)
        blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2]
        blocked_w = (
            torch.nn.functional.pad(W, (0, padding))  # type: ignore[assignment]
            .reshape(*blocked_size)
            .transpose(-3, -2)
            .contiguous()
        )
    return blocked_w
1245
+
1246
@classmethod
def pack_vnni_weight(cls, W, micro_gemm, new_size):
    """Pack an already-blocked weight into the microkernel's VNNI layout.

    For constant IR buffers the packing is assumed to have been done on the
    tensor data already, so the node is returned unchanged. Otherwise the
    interleave is expressed as view/permute/view (IR) or view/transpose
    (eager tensor).
    """
    # WOQ INT4 weights are reordered in microkernel so do not pack them here
    should_pack = (
        micro_gemm.get_b_layout() != LayoutType.NORMAL
        and not micro_gemm.is_woq_int4()
    )

    # These are separated into two methods to allow subclasses to override them separately
    if isinstance(W, ir.IRNode):
        if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants:
            # Constant blocked buffer: packed elsewhere on the tensor data.
            return W
        k = new_size[-2]
        if not isinstance(W, ir.TensorBox):
            W = ir.TensorBox(W)
        if should_pack:
            # Interleave groups of `vnni_size` rows along k:
            # view [..., k//v, v, block_n] -> swap last two dims -> view back.
            permute_dims = list(range(len(new_size) + 1))
            permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1]
            vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
            vnni_view_size = list(new_size)
            vnni_view_size[-2] = k // vnni_size
            vnni_view_size.insert(-1, vnni_size)
            W = L.view(
                L.permute(L.view(W, vnni_view_size), permute_dims),
                new_size,
            )
        W = ir.ExternKernel.realize_input(W)
        W = ir.ExternKernel.require_contiguous(W)
        return W
    else:
        k = new_size[-2]
        # Apply VNNI packing to the weight tensor
        if should_pack:
            # TODO: Move VNNI weight packing for non-constant tensors into the template,
            # to improve cache locality and avoid full-tensor copy.
            layout_str = (
                "VNNI4"
                if micro_gemm.get_b_layout() == LayoutType.VNNI4
                else "VNNI2"
            )
            assert micro_gemm.get_b_layout() in [
                LayoutType.VNNI2,
                LayoutType.VNNI4,
            ], f"We only support {layout_str} for now"
            vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
            assert k % vnni_size == 0, (
                f"k should be divisible by vnni_size for {layout_str} layout"
            )
            vnni_view_size = list(new_size)
            vnni_view_size[-2] = k // vnni_size
            vnni_view_size.insert(-1, vnni_size)
            W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size)
        # normalize stride to be "contiguous_strides" per size
        # this avoids the problems in L.view during template codegen
        new_stride = [1]
        for sz in reversed(W.shape[1:]):
            new_stride.insert(0, new_stride[0] * sz)
        W = W.as_strided(W.shape, new_stride)
        return W
1305
+
1306
def get_default_reindexers(self, epilogue_nodes):
    """Return one identity (``None``) reindexer per epilogue node."""
    return [None for _ in epilogue_nodes]
1308
+
1309
def get_options(
    self,
    kernel: CppTemplateKernel,
    template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
    flag_template_buffer_has_other_users: Optional[bool] = None,
    epilogue_nodes: Optional[list[ir.IRNode]] = None,
) -> dict[str, Any]:
    """Assemble the keyword options used to render the GEMM Jinja template.

    Resolves the X/W/bias/scale inputs for the fp, int8 and WOQ-int4
    variants, builds the in-template epilogue chain, instantiates the
    microkernel, and returns everything the template needs as one dict.
    """
    assert len(self.input_nodes) >= 2

    # Input-node layout differs per GEMM flavor; unpack accordingly.
    int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8]
    x_scale = None
    x_zp = None
    w_scale = None
    w_zp = None
    inp = None
    q_group_size_node = None
    qscale_and_zeros = None
    if int8_gemm:
        X, W = self.input_nodes[0], self.input_nodes[1]
        bias_idx = 2 if self.has_bias else 1
        inp = self.input_nodes[bias_idx] if self.has_bias else None
        x_scale = self.input_nodes[bias_idx + 1]
        x_zp = self.input_nodes[bias_idx + 2]
        w_scale = self.input_nodes[bias_idx + 3]
        w_zp = self.input_nodes[bias_idx + 4]
        Y = self.output_node
    elif self.is_woq_int4():
        X, W = self.input_nodes[0], self.input_nodes[1]
        Y = self.output_node
        q_group_size_node = self.input_nodes[2]
        qscale_and_zeros = self.input_nodes[3]
    else:
        X, W = self.input_nodes[0], self.input_nodes[1]
        Y = self.output_node
        inp = self.input_nodes[2] if self.has_bias else None

    template_buffer_has_other_users = None

    if template_buffer_node is not None:
        # Use the updated prepacked weight buffer
        W = template_buffer_node.inputs[1]
        Y = template_buffer_node

        assert flag_template_buffer_has_other_users is not None
        template_buffer_has_other_users = flag_template_buffer_has_other_users

    template_buffer = Y
    gemm_output_buffer = template_buffer

    epilogues: list[ir.IRNode] = []
    reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
    epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
    fake_buffers: list[ir.Buffer] = []
    Y_aliases: OrderedSet[str] = OrderedSet()

    # A thread-local accumulator is needed whenever the raw GEMM result
    # cannot be written straight into the (fp32, unpadded, single-user) output.
    use_local_acc = (
        self.layout.dtype != torch.float
        or template_buffer_has_other_users
        or int8_gemm
        or self.padded_n != self.n
        or self.maybe_k_slicing()
        or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype)
    )

    # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
    # but we'd better move it here to align with fp.
    if inp is not None and self.beta != 0 and not int8_gemm:
        # add an epilogue for bias add
        def _bias_add_epilogue(buf):
            return create_epilogue_with_attr(
                buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
            )

        epilogue_creators.append(_bias_add_epilogue)

    if self.epilogue_creator is not None:
        epilogue_creators.append(self.epilogue_creator)

    # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
    # we need to copy the value in the GEMM output local buffer to a global buffer.
    def need_copy_from_local_to_global_buffer_epilogue(
        use_local_acc, template_buffer_has_other_users, epilogue_creators
    ):
        # The GEMM output buffer is a global buffer, thus copy is not needed.
        if not use_local_acc:
            return False

        # The possible value of template_buffer_has_other_users is (None, False, True)
        # It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
        # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
        # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
        # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
        #    GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
        if not template_buffer_has_other_users:
            return False

        # When bias is not None or self.epilogue_creator is not None,
        # there will be epilogue_creators after the GEMM.
        # The GEMM output buffer is localized while
        # the output buffer of the epilogue_creators is a global buffer.
        if epilogue_creators:
            return False

        return True

    if need_copy_from_local_to_global_buffer_epilogue(
        use_local_acc, template_buffer_has_other_users, epilogue_creators
    ):

        def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
            dtype = self.layout.dtype
            input_loader = input_buffer.make_loader()

            def copy_inner(index):
                input = input_loader(index)
                result = ops.to_dtype(input, dtype)
                return result

            return ir.Pointwise(
                device=input_buffer.get_device_or_error(),
                dtype=self.layout.dtype,
                inner_fn=copy_inner,
                ranges=input_buffer.get_size(),
            )

        epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)

    # NOTE [How CPP GEMM template epilogues are organized]
    #   gemm_output_buffer
    #     --> zero or more in-template epilogues (created by `epilogue_creators`) -->
    #   template_buffer
    #     --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
    #   Y
    if epilogue_creators:
        assert isinstance(template_buffer, ir.IRNode)
        gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
        gemm_output_buffer = ir.Buffer(
            name=gemm_output_name, layout=template_buffer.layout
        )
        current_input_buffer = gemm_output_buffer
        for i, creator in enumerate(epilogue_creators):
            # The last in-template epilogue writes into the template buffer
            # itself; intermediates get synthesized names.
            if i == len(epilogue_creators) - 1:
                buffer_name = template_buffer.get_name()
            else:
                buffer_name = f"{gemm_output_name}_epilogue_{i}"
            epilogues.append(
                ir.ComputedBuffer(
                    name=buffer_name,
                    layout=template_buffer.layout,
                    data=creator(current_input_buffer),
                )
            )
            fake_buffers.append(current_input_buffer)
            Y_aliases.add(current_input_buffer.get_name())
            reindexers.append(None)
            if i < len(epilogue_creators) - 1:
                current_input_buffer = ir.Buffer(
                    name=buffer_name, layout=template_buffer.layout
                )

    assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
    Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y

    if epilogue_nodes:
        if not template_buffer_has_other_users:
            assert isinstance(template_buffer, ir.IRNode)
            Y_aliases.add(template_buffer.get_name())
        epilogues.extend(epilogue_nodes)
        assert Y.get_numel() == epilogues[-1].get_numel()
        Y = cast(ir.Buffer, epilogues[-1])
        assert isinstance(template_buffer, ir.Buffer)
        # The template computes in 2D; map the final epilogue output back
        # to a 2D view with matching reindexers.
        Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
            Y,
            template_buffer,
            epilogue_nodes,
            reindexers,
            default_reindexers=self.get_default_reindexers(epilogue_nodes),
        )

    output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
        X.get_dtype()
    )
    micro_gemm = create_micro_gemm(
        f"{kernel.kernel_name}_micro_gemm",
        self.m,
        self.n,
        self.k,
        input_dtype=X.get_dtype(),
        input2_dtype=W.get_dtype(),
        output_dtype=output_dtype,
        compute_dtype=compute_dtype,
        alpha=self.alpha,
        num_threads=self.num_threads,
        use_ref=not self.is_woq_int4(),
        q_group_size=self.q_group_size(),
    )
    assert micro_gemm is not None
    micro_gemm.use_local_vnni_blocking(not self.should_block_weights)
    # Must match the blocking chosen at template-creation time.
    assert self.register_blocking == micro_gemm.register_blocking
    self.log_blockings()
    if isinstance(micro_gemm, CppMicroGemmAMX):
        counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
    if isinstance(micro_gemm, CppMicroBrgemm):
        counters["inductor"]["cpp_micro_brgemm_counter"] += 1

    L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
    assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"

    L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
    assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"

    options = dict(
        X=X,
        W=W,
        inp=inp,
        Y=Y,
        N=self.n,
        K=self.k,
        PADDED_N=self.padded_n,
        GemmOut=gemm_output_buffer,
        aliases={alias: Y.get_name() for alias in Y_aliases},
        beta=self.beta,
        alpha=self.alpha,
        num_threads=self.num_threads,
        micro_gemm=micro_gemm,
        is_dynamic_M=self.is_dynamic_M,
        template=self,
        kernel=kernel,
        export_declaration=get_export_declaration(),
        epilogue_nodes=epilogues,
        reindexers=reindexers,
        Y_2d=Y_2d,
        use_local_acc=use_local_acc,
        maybe_k_slicing=self.maybe_k_slicing(),
        x_scale=x_scale,
        x_zp=x_zp,
        w_scale=w_scale,
        w_zp=w_zp,
        acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
        DTYPE_TO_CPP=DTYPE_TO_CPP,
        L1_cache_size=L1_cache_size,
        L2_cache_size=L2_cache_size,
        config=config,
        fake_buffers=fake_buffers,
        is_woq_int4=self.is_woq_int4(),
        q_group_size=q_group_size_node,
        qscale_and_zeros=qscale_and_zeros,
    )
    return options
1558
+
1559
def is_int8_woq_gemm_small_m_dim(
    self,
    X: ir.ReinterpretView,
    W: ir.ReinterpretView,
    N,
    K,
    micro_gemm,
):
    """Decide whether the specialized SMALL_M_GEMM_TEMPLATE should be used.

    Applies only to the bf16-activation / int8-weight combination running on
    the FP32 vectorized microkernel, and only when the corner-case check on
    the (small) M dimension passes.
    """
    if not isinstance(micro_gemm, CppMicroGemmFP32Vec):
        return False
    m_dim = X.get_size()[0]
    if not is_int8_woq_gemm_small_m_dim_corner_case(micro_gemm, m_dim, N, K):
        return False
    return X.get_dtype() is torch.bfloat16 and W.get_dtype() is torch.int8
1576
+
1577
def render(  # type: ignore[override, return]
    self,
    kernel: CppTemplateKernel,
    template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
    flag_template_buffer_has_other_users: Optional[bool] = None,
    epilogue_nodes: Optional[list[ir.IRNode]] = None,
    **kwargs,
) -> str:
    """Render the C++ GEMM kernel source for this template instance.

    Computes the render options (also cached on ``self.render_options`` for
    the ``codegen_*`` helpers), then renders either the small-M specialized
    template or the general GEMM template.
    """
    options = self.get_options(
        kernel=kernel,
        template_buffer_node=template_buffer_node,
        flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
        epilogue_nodes=epilogue_nodes,
    )
    self.render_options = options

    with contextlib.ExitStack() as stack:
        # The intermediate epilogue buffers are not registered with the
        # graph, so V.graph.get_dtype is patched to resolve their dtypes
        # for the duration of rendering.
        for buf in options["fake_buffers"]:
            stack.enter_context(
                patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
            )
        if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim(
            options["X"],
            options["W"],
            options["N"],
            options["K"],
            options["micro_gemm"],
        ):
            template_str = SMALL_M_GEMM_TEMPLATE
        else:
            template_str = GEMM_TEMPLATE
        return self._template_from_string(template_str).render(**options)
1609
+
1610
def codegen_blocks(
    self,
    num_threads,
    N,
    K,
    micro_gemm,
    is_dynamic_M,
    kernel,
    GemmOut,
    config,
    L1_cache_size,
    L2_cache_size,
    X,
    W,
):
    """Render the cache-blocking initialization code for the GEMM kernel.

    Always emits the basic blocking section; the extended section is added
    except when the small-M specialized path applies.
    """
    render_ctx = {
        "num_threads": num_threads,
        "N": N,
        "K": K,
        "micro_gemm": micro_gemm,
        "is_dynamic_M": is_dynamic_M,
        "kernel": kernel,
        "GemmOut": GemmOut,
        "config": config,
        "L1_cache_size": L1_cache_size,
        "L2_cache_size": L2_cache_size,
        "template": self,
        "X": X,
        "W": W,
        "is_woq_int4": self.is_woq_int4(),
    }
    source = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK
    # Equivalent to: not (static M and small-M specialization applies).
    if is_dynamic_M or not self.is_int8_woq_gemm_small_m_dim(
        X, W, N, K, micro_gemm
    ):
        source += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED
    return self._template_from_string(source).render(render_ctx)
1648
+
1649
def codegen_microkernel_def(self):
    """Render the microkernel definition using the cached render options."""
    microkernel_template = self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF)
    return microkernel_template.render(self.render_options)
1653
+
1654
def codegen_gemm_stub_def(self):
    """Render the GEMM stub, prefixed by its microkernel definition."""
    prefix = self.codegen_microkernel_def()
    stub = self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render(
        self.render_options
    )
    return prefix + stub
1659
+
1660
def codegen_multi_threads_params(self):
    """Render the per-thread blocking parameter setup for the parallel path."""
    params_template = self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS)
    return params_template.render()
1662
+
1663
def codegen_single_thread_params(self, is_dynamic_M):
    """Render the blocking parameter setup for the single-threaded path."""
    params_template = self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS)
    return params_template.render({"is_dynamic_M": is_dynamic_M})
1670
+
1671
def codegen_m_loop_params(self):
    """Render the loop-variable setup for the M (row-block) loop."""
    loop_template = self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS)
    return loop_template.render()
1673
+
1674
def codegen_n_loop_params(self):
    """Render the loop-variable setup for the N (column-block) loop."""
    loop_template = self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS)
    return loop_template.render()
1676
+
1677
@classmethod
def is_woq_int4(cls):
    # Base template is not a weight-only-quantized INT4 GEMM;
    # CppWoqInt4GemmTemplateInstance overrides this to return True.
    return False
1680
+
1681
@classmethod
def q_group_size(cls):
    # Quantization group size; only meaningful for WOQ INT4 subclasses,
    # which override this with the size selected via CppWoqInt4GemmTemplate[...].
    return None
1684
+
1685
+
1686
class CppWoqInt4GemmTemplateMeta(type):
    """Metaclass enabling ``CppWoqInt4GemmTemplate[q_group_size]``.

    Indexing produces a fresh ``CppGemmTemplate`` subclass whose
    ``q_group_size()`` returns the given group size, with the WOQ INT4
    behavior overrides baked in.
    """

    def __getitem__(cls, q_group_size):
        class CppWoqInt4GemmTemplateInstance(CppGemmTemplate):
            def __init__(
                self,
                *args,
                **kwargs,
            ) -> None:
                super().__init__(
                    *args,
                    **kwargs,
                )

            @classmethod
            def is_woq_int4(cls):
                # Marks this template as the WOQ INT4 variant.
                return True

            @classmethod
            def q_group_size(cls):
                # Captured from the CppWoqInt4GemmTemplate[...] subscript.
                return q_group_size

            @staticmethod
            def check_if_block_weight(W, micro_gemm):
                # For WOQ INT4, weight is already packed
                # However, for AMX microkernel, we want to change the blocking of weight
                from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx

                return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx)

            @classmethod
            def block_weight(cls, W, new_size, padding):
                # This method is called only if AMX microkernels are used.
                # In this case, we unpack and repack weight so that block_n=32
                # the format of packed weight is described here:
                # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583
                if isinstance(W, ir.IRNode):
                    # in this case, we do nothing
                    ir.ExternKernel.require_contiguous(W)
                    blocked_w = W
                else:
                    # in this case, we unpack and repack weight
                    assert isinstance(W, torch.Tensor)
                    assert W.dim() == 2
                    N = W.size(0)
                    # Each int8 byte holds two int4 values, so K is twice
                    # the packed last dimension.
                    K = W.size(-1) * 2
                    G = cls.q_group_size()
                    # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel
                    # so that the unpacking process is faster
                    x = torch.eye(K).bfloat16()
                    # Here we use scale=1 and qzero=8 because we want to unpack weight
                    # without dequantizing it. The qzero here is 8 instead of 0 because
                    # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel:
                    # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95
                    qscales_and_zeros = (
                        torch.tensor([1.0, 8.0])
                        .bfloat16()
                        .expand(K // G, N, 2)
                        .contiguous()
                    )
                    # shape: [K, N]
                    # Multiplying by the identity recovers the raw int4
                    # values via the existing optimized kernel.
                    unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu(
                        x,
                        W,
                        G,
                        qscales_and_zeros,
                    ).to(torch.uint8)
                    block_n = 32
                    # shape: [N // block_n, K, block_n]
                    w_blocked = (
                        unpacked_w.view(K, N // block_n, block_n)
                        .permute(1, 0, 2)
                        .contiguous()
                    )
                    # pack 2 int4 -> 1 int8
                    # block_n: [a0, a1, ..., a15, b0, b1, ..., b15]
                    # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...]
                    # shape: [N // block_n, K, 2, block_n // 2]
                    w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2)
                    # shape: [N // block_n, K, block_n // 2]
                    w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | (
                        w_blocked[:, :, 1, :] << 4
                    )
                    # shape: [N, K // 2]
                    blocked_w = w_blocked_packed.view(N, K // 2)

                return blocked_w

        return CppWoqInt4GemmTemplateInstance
1774
+
1775
+
1776
class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta):
    """Subscriptable factory: ``CppWoqInt4GemmTemplate[group_size]`` yields a
    WOQ INT4 GEMM template class bound to that quantization group size."""

    pass
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import logging
3
+ from typing import Any, Callable, cast, Optional, TypeVar
4
+ from unittest.mock import patch
5
+
6
+ import torch
7
+ import torch.utils
8
+ from torch.utils._ordered_set import OrderedSet
9
+
10
+ from ..._dynamo.utils import counters
11
+ from .. import config, ir
12
+ from ..kernel.mm_common import mm_args
13
+ from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
14
+ from ..utils import parallel_num_threads
15
+ from ..virtualized import V
16
+ from .cpp import get_export_declaration
17
+ from .cpp_gemm_template import (
18
+ CppGemmTemplate,
19
+ expand_bias,
20
+ gen_2d_view_of_epilogue_buf,
21
+ prune_tensors,
22
+ transpose_w,
23
+ )
24
+ from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
25
+ from .cpp_template_kernel import CppTemplateKernel
26
+ from .cpp_utils import (
27
+ create_epilogue_with_attr,
28
+ DTYPE_TO_CPP,
29
+ GemmBlocking,
30
+ get_gemm_template_output_and_compute_dtype,
31
+ )
32
+
33
+
34
+ log = logging.getLogger(__name__)
35
+
36
+ GEMM_TEMPLATE = r"""
37
+ {{template.header().getvalue()}}
38
+ {{micro_gemm.codegen_define(kernel)}}
39
+
40
+ extern "C" {{export_declaration}}
41
+ {{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
42
+ {
43
+ {{kernel.maybe_codegen_profile()}}
44
+ {{ template.codegen_blocks(
45
+ num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
46
+ ) }}
47
+ {%- if num_threads > 1 %}
48
+ #pragma omp parallel num_threads({{num_threads}})
49
+ {
50
+ {{ template.codegen_multi_threads_params()|indent(8, false) }}
51
+ {%- else %}
52
+ {
53
+ {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
54
+ {%- endif %}
55
+ {{ micro_gemm.codegen_init(kernel) }}
56
+ {%- set acc_buf_name_list=[] %}
57
+ {%- set acc_buf_name_prefix = "local_acc_buf_" %}
58
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
59
+ {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
60
+ {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
61
+ {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
62
+ {%- endfor %}
63
+ for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
64
+ {{ template.codegen_m_loop_params()|indent(12, false) }}
65
+ for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
66
+ {{ template.codegen_n_loop_params()|indent(16, false) }}
67
+ {%- set acc_list=[] %}
68
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
69
+ {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
70
+ {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
71
+ {%- endfor %}
72
+ for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
73
+ int64_t k_start = kc * Kr;
74
+ int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
75
+ {%- set tile_X_list=[] %}
76
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
77
+ {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
78
+ {%- endfor %}
79
+ for (int64_t nci = nc; nci < nc_block_end; nci++) {
80
+ {%- set tile_W_3d_list=[] %}
81
+ {%- set tile_W_list=[] %}
82
+ {%- set acc_slice_list=[] %}
83
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
84
+ {%- set acc_slice_list = acc_slice_list.append(
85
+ kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
86
+ ) %}
87
+ {%- set tile_W_3d_list = tile_W_3d_list.append(
88
+ kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
89
+ ) %}
90
+ {%- endfor %}
91
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
92
+ {%- set tile_W_list = tile_W_list.append(
93
+ kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
94
+ ) %}
95
+ {%- endfor %}
96
+ if (kc == k_block_start) {
97
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
98
+ {{ micro_gemm.codegen_call(
99
+ kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
100
+ )|indent(28, false) }}
101
+ {%- endfor %}
102
+ } else {
103
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
104
+ {{ micro_gemm.codegen_call(
105
+ kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
106
+ )|indent(28, false) }}
107
+ {%- endfor %}
108
+ }
109
+ }
110
+ }
111
+ {
112
+ {%- set tile_acc_list = [] %}
113
+ {%- set tile_Y_list = [] %}
114
+ {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
115
+ {%- set tile_acc_list = tile_acc_list.append(
116
+ kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
117
+ ) %}
118
+ {%- set tile_Y_list = tile_Y_list.append(
119
+ kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
120
+ ) %}
121
+ {%- endfor %}
122
+ {{ kernel.store_outputs(
123
+ tile_Y_list,
124
+ tile_acc_list,
125
+ GemmOuts,
126
+ epilogue_nodes,
127
+ offsets=("m_start", "n_start"),
128
+ reindexers=reindexers,
129
+ multi_output_buffers=multi_output_buffers
130
+ )|indent(20, false)
131
+ }}
132
+ }
133
+ }
134
+ }
135
+ {{ micro_gemm.codegen_finalize(kernel) }}
136
+ }
137
+ }
138
+ """
139
+
140
+
141
def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]:
    """Return the distinct activation nodes of a grouped GEMM, in GEMM-index order.

    Args:
        act_mapping: Maps each GEMM index (assumed to be the contiguous range
            ``0 .. len(act_mapping) - 1``) to its activation node. Several
            GEMMs may share the same activation buffer.

    Returns:
        The activation nodes with duplicates (keyed by buffer name) removed,
        ordered by first occurrence over ascending GEMM index.
    """
    act_deduplicated = []
    act_deduplicated_name: OrderedSet[str] = OrderedSet()
    # Iterate by GEMM index rather than dict insertion order so the result is
    # deterministic even if the mapping was populated out of order.
    # (len(act_mapping) replaces the redundant len(act_mapping.values()).)
    for act_idx in range(len(act_mapping)):
        act = act_mapping[act_idx]
        if act.get_name() not in act_deduplicated_name:
            act_deduplicated.append(act)
            act_deduplicated_name.add(act.get_name())
    return act_deduplicated
150
+
151
+
152
class CppGroupedGemmTemplate(CppGemmTemplate):
    # Codegen template for a *group* of same-shaped GEMMs that share one C++
    # kernel; see the __init__ docstring for the exact contract.

    def __init__(
        self,
        input_nodes: list[ir.IRNode],
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta: int = 1,
        alpha: int = 1,
        has_bias: bool = False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
        gemm_grouped_num: int = 1,
    ) -> None:
        """
        Template for Group of GEMMs:
        * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
        for their A, B, and C matrices.
        * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
        * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
        This behavior can be extended in the future if needed.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta,
            alpha,
            has_bias,
            epilogue_creator,
        )
        # act_mapping: GEMM index -> activation node (activations may be shared).
        self.act_mapping = act_mapping
        self.gemm_grouped_num = gemm_grouped_num
        # One placeholder output buffer per GEMM in the group; replaced by the
        # real template-buffer outputs in render() once scheduling has run.
        self.output_node: list[ir.Buffer] = [
            ir.Buffer(name="buf_out" + str(idx), layout=layout)
            for idx in range(gemm_grouped_num)
        ]

    @classmethod
    def add_choices(
        cls,
        choices: list[ChoiceCaller],
        layout: ir.Layout,
        input_nodes: list[ir.IRNode],
        beta: int = 1,
        alpha: int = 1,
        has_bias: tuple[bool, ...] = (False, False),
        trans_w: bool = False,
        input_indices: Optional[list[int]] = None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,  # gemm idx to its act buf
    ) -> DataProcessorTemplateWrapper:
        """Register this grouped-GEMM template as an autotuning choice.

        Builds the pre/post-processing pipeline (input reordering, dense
        conversion, weight transpose/bias expansion, VNNI weight packing) and
        appends a DataProcessorTemplateWrapper choice to `choices`.

        `has_bias` has one entry per GEMM, so its length defines the group size.
        """
        # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
        gemm_grouped_num = len(has_bias)
        assert act_mapping
        act_deduplicated = get_deduplicated_act(act_mapping)
        # Layout of the flattened input list: activations first, then one
        # weight per GEMM, then the (optional) biases.
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + gemm_grouped_num
        # NOTE(review): the `input_indices` parameter is unconditionally
        # overwritten here with the identity permutation — TODO confirm
        # whether a caller-supplied ordering should be honored instead.
        input_indices = list(range(len(input_nodes)))

        _T = TypeVar("_T", ir.IRNode, torch.Tensor)
        _U = TypeVar("_U", ir.Layout, torch.Tensor)

        def reorder_and_filter(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Reorder inputs according to input_indices (currently identity).
            assert input_indices is not None, "input_indices must be set"
            return [inputs[idx] for idx in input_indices], layout_or_out

        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)

        def maybe_to_dense(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Convert any MKLDNN weight tensors to dense layout.
            new_inputs = list(inputs)
            for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                if isinstance(inputs[idx], torch.Tensor):
                    W = inputs[idx]
                    assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
                    new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # When trans_w is set, transpose every weight and expand every
            # bias against the activation shape; otherwise a no-op.
            new_inputs: list[_T] = list(inputs)
            if not trans_w:
                return new_inputs, layout_or_out
            X = new_inputs[0]
            for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                new_input = new_inputs[wgt_idx]
                new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
            for bias_idx in range(bias_start_idx, len(new_inputs)):
                new_bias = expand_bias(new_inputs[bias_idx], X)
                assert new_bias is not None
                new_inputs[bias_idx] = new_bias
            return new_inputs, layout_or_out

        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        # All GEMMs share (m, n, k), so probing the first X/W pair suffices.
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
        # Pad N up to a multiple of the micro-kernel's block_n.
        new_size, padded_n = cls.get_padded_size(
            n, block_n, k, should_block_weight=True
        )
        padding = padded_n - n

        def pack_weight(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Block and VNNI-pack every weight in the group for the chosen
            # micro-kernel layout.
            new_W_list = []
            new_inputs = list(inputs)
            W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
            for W in W_list:
                blocked_w = cls.block_weight(W, new_size, padding)
                new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
            new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
            return new_inputs, layout_or_out

        def preprocessor(
            inputs: list[_T],
            layout: _U,
        ) -> tuple[list[_T], _U]:
            # Full input pipeline: reorder -> densify -> normalize -> pack.
            return pack_weight(
                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
            )

        def postprocessor(output: _T) -> _T:
            # After the template is materialized, replace the constant weight
            # inputs of the template buffer with their packed counterparts and
            # drop the now-unused originals from the graph constants.
            if isinstance(output, ir.TensorBox):
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
                W_nodes = new_input_nodes[
                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
                ]
                W_tensor = []
                for W_node in W_nodes:
                    assert W_node.get_name() in V.graph.constants
                    W_tensor.append(V.graph.constants[W_node.get_name()])
                new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
                    W_tensor  # type: ignore[assignment]
                )
                new_input_nodes, _ = pack_weight(
                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                )
                # Prune unused tensors
                prune_tensors(input_nodes, new_input_nodes)
                for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                    W_packed = new_input_nodes[idx]
                    assert isinstance(W_packed, torch.Tensor)
                    W_packed_constant = V.graph.add_tensor_constant(W_packed)
                    template_buffer.inputs[idx] = (
                        ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
                    )
            return output

        template = DataProcessorTemplateWrapper(
            CppGroupedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            act_mapping=act_mapping,
            gemm_grouped_num=gemm_grouped_num,
        )
        template.maybe_append_choice(choices)
        return template

    def render(  # type: ignore[override,return,no-untyped-def]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[list[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """Render the grouped-GEMM C++ kernel source from the Jinja template.

        Collects per-GEMM activations, packed weights, optional biases and
        epilogues, recreates the micro-GEMM for this kernel name, and renders
        GEMM_TEMPLATE with the assembled options.
        """
        assert self.act_mapping
        act_deduplicated = get_deduplicated_act(self.act_mapping)
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + self.gemm_grouped_num
        X_list = list(self.act_mapping.values())
        W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
        # inp_list[i] is GEMM i's bias input, or None when has_bias[i] is False;
        # biases are packed densely after the weights, hence the running cursor.
        inp_list = []
        cur_idx = bias_start_idx
        for inp_idx in range(self.gemm_grouped_num):
            inp = None
            if self.has_bias[inp_idx]:
                inp = self.input_nodes[cur_idx]
                cur_idx += 1
            inp_list.append(inp)

        Y_list = self.output_node
        multi_output_buffers = None
        if template_buffer_node is not None:
            # Post-scheduling: take the packed weights and the real multi-output
            # buffers from the materialized template buffer node.
            W_list = template_buffer_node.inputs[
                wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
            ]
            assert isinstance(template_buffer_node.outputs, list)
            Y_list = template_buffer_node.outputs
            counters["inductor"]["cpp_grouped_gemm_template"] += 1
            multi_output_buffers = template_buffer_node.outputs

        template_buffer = Y_list[0]
        fake_buffers: list[ir.Buffer] = []
        Y_2d_list = Y_list
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            X_list[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            f"{kernel.kernel_name}_micro_gemm",
            self.m,
            self.n,
            self.k,
            input_dtype=X_list[0].get_dtype(),
            input2_dtype=W_list[0].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
        )
        assert micro_gemm is not None
        # The micro-GEMM recreated here must agree with the one selected in
        # add_choices, otherwise the packed weights would not match.
        assert self.register_blocking == micro_gemm.register_blocking
        self.log_blockings()
        if isinstance(micro_gemm, CppMicroGemmAMX):
            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1

        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"

        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"

        epilogues: list[ir.IRNode] = []
        reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
        # One intermediate accumulation buffer per GEMM, named after the
        # template buffer (<buf>_GemmOut<i>).
        gemm_output_buffers: list[ir.Buffer] = []
        for out_buf_idx in range(self.gemm_grouped_num):
            gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
                out_buf_idx
            )
            gemm_output_buffers.append(
                ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
            )

        assert not self.epilogue_creator, (
            "epilogue_creator is not supported yet in Grouped GEMM Template"
        )

        # Kernel argument naming convention consumed by the Jinja template:
        # X<i> for deduplicated activations, W<i> for weights, inp<i> for biases
        # (inp<i> may be None when the GEMM has no bias).
        kernel_args: dict[str, Optional[ir.IRNode]] = {}
        for x_idx in range(wgt_start_idx):
            kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
        for w_idx in range(self.gemm_grouped_num):
            kernel_args["W" + str(w_idx)] = W_list[w_idx]
        for inp_idx in range(self.gemm_grouped_num):
            kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

        def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise:
            # Fuse `Y = buf + beta * inp` as a pointwise epilogue.
            return create_epilogue_with_attr(
                buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
            )

        for gemm_idx, inp in enumerate(inp_list):
            if inp:
                buffer_name = Y_list[gemm_idx].get_name()
                epilogues.append(
                    ir.ComputedBuffer(
                        name=buffer_name,
                        layout=template_buffer.layout,
                        data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp),
                    )
                )
                reindexers.append(None)

        if epilogue_nodes:
            epilogues.extend(epilogue_nodes)
            for epilogue_node in epilogue_nodes:
                Y = cast(ir.Buffer, epilogue_node)
                _, reindexers = gen_2d_view_of_epilogue_buf(
                    Y,
                    template_buffer,
                    [
                        epilogue_node,
                    ],
                    reindexers,
                    default_reindexers=[
                        None,
                    ],
                )

        options = dict(
            N=self.n,
            K=self.k,
            PADDED_N=self.padded_n,
            aliases={},
            beta=self.beta,
            alpha=self.alpha,
            num_threads=self.num_threads,
            micro_gemm=micro_gemm,
            is_dynamic_M=self.is_dynamic_M,
            template=self,
            kernel=kernel,
            export_declaration=get_export_declaration(),
            acc_buf_dtype=torch.float,
            DTYPE_TO_CPP=DTYPE_TO_CPP,
            L1_cache_size=L1_cache_size,
            L2_cache_size=L2_cache_size,
            config=config,
            epilogue_nodes=epilogues,
            GemmOuts=gemm_output_buffers,
            reindexers=reindexers,
            kernel_args=kernel_args,
            X_list=X_list,
            W_list=W_list,
            gemm_grouped_num=self.gemm_grouped_num,
            Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
            Y_2d_list=Y_2d_list,
            multi_output_buffers=multi_output_buffers,
        )
        with contextlib.ExitStack() as stack:
            # Fake dtype resolution for buffers that only exist inside the
            # template while rendering.
            stack.enter_context(
                patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
            )
            return self._template_from_string(GEMM_TEMPLATE).render(**options)
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py ADDED
@@ -0,0 +1,2011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import dataclasses
3
+ import operator
4
+ import sys
5
+ from enum import Enum
6
+ from typing import Callable, Optional
7
+
8
+ import torch
9
+
10
+ from .. import cpp_builder, ir
11
+ from ..cpu_vec_isa import (
12
+ pick_vec_isa,
13
+ VecAMX,
14
+ VecAVX2,
15
+ VecAVX512,
16
+ VecISA,
17
+ VecNEON,
18
+ VecSVE256,
19
+ )
20
+ from ..utils import IndentedBuffer, parallel_num_threads
21
+ from ..virtualized import V
22
+ from .common import KernelTemplate
23
+ from .cpp_template_kernel import CppTemplateKernel
24
+ from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
25
+
26
+
27
class LayoutType(Enum):
    """Memory layout of the weight (B) operand expected by a micro-GEMM kernel.

    NORMAL is an unpacked layout; VNNI2/VNNI4 correspond to VNNI-interleaved
    packing (matching the ``vnni_size`` of 2 or 4 exposed by
    ``CppMicroGemm.get_common_options``).
    """

    NORMAL = 0
    VNNI2 = 1
    VNNI4 = 2
31
+
32
+
33
# True when running on a Windows host (MSVC toolchain).
_IS_WINDOWS = sys.platform == "win32"


def get_restrict_keyword() -> str:
    """Return the compiler-specific spelling of the C/C++ `restrict` qualifier.

    MSVC uses `__restrict`
    (https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170),
    while GCC/Clang use `__restrict__`.
    """
    return "__restrict" if _IS_WINDOWS else "__restrict__"
42
+
43
+
44
class CppMicroGemm:
    """
    A class that codegens a kernel that computes small-sized matrix multiplication.

    A micro GEMM kernel is responsible for register blocking, instruction selection,
    and other CPU architecture-specific optimizations.

    The subclasses need to override `codegen_define` to define the kernel function
    that is called by the code generated by `codegen_call`.
    """

    # TODO(jgong5): support constant shapes and lds as template args.
    # Jinja template for the C++ declaration of the micro-kernel; the `accum`
    # template parameter selects C += A@B vs C = A@B (see codegen_call).
    DECLARE_KERNEL = r"""
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
    {{kernel_extra_args_declare}}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
)
"""

    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ) -> None:
        """Record kernel name, operand dtypes, register blocking and alpha.

        `register_blocking` is a GemmBlocking (block_m, block_n, block_k) tile.
        """
        self.name = name
        self.input_dtype = input_dtype
        assert input2_dtype is not None
        self.input2_dtype = input2_dtype
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha
        # When True, B is VNNI-packed inside the kernel instead of ahead of
        # time (toggled via use_local_vnni_blocking).
        self.pack_vnni_B_locally = False

    def get_common_options(self):
        """Return the option dict shared by all micro-GEMM Jinja templates."""
        # int8 GEMM contract: u8/s8 A requires s8 B with int32 accumulate/output.
        if self.input_dtype in [torch.uint8, torch.int8]:
            assert self.compute_dtype == torch.int32
            assert self.output_dtype == torch.int32
            assert self.input2_dtype == torch.int8
        return {
            "torch": torch,
            "kernel_name": self.name,
            "input_dtype": self.input_dtype,
            "input2_dtype": self.input2_dtype,
            "output_dtype": self.output_dtype,
            "compute_dtype": self.compute_dtype,
            "input_t": DTYPE_TO_CPP[self.input_dtype],
            "input2_t": DTYPE_TO_CPP[self.input2_dtype],
            "output_t": DTYPE_TO_CPP[self.output_dtype],
            "compute_t": DTYPE_TO_CPP[self.compute_dtype],
            "alpha": self.alpha,
            "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
            "int8_gemm": self.input_dtype in [torch.uint8, torch.int8],
            # VNNI interleave width: 4 for int8 inputs, 2 otherwise.
            "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2,
            "restrict_keyword": get_restrict_keyword(),
            "pack_vnni_B_locally": self.pack_vnni_B_locally,
            "template": self,
            "is_woq_int4": self.is_woq_int4(),
        }

    def get_kernel_declaration(self):
        """Render DECLARE_KERNEL with the common options."""
        options = self.get_common_options()
        return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)

    def get_kernel_extra_args_declare(self) -> str:
        # No extra kernel parameters by default; subclasses may override.
        return ""

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        # No extra call-site arguments by default; subclasses may override.
        return []

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Emit the C++ definition of the micro-kernel (subclass hook)."""
        raise NotImplementedError

    def codegen_call(
        self,
        kernel: CppTemplateKernel,
        A: ir.Buffer,
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
        prefetch: bool = False,
        **kwargs_for_extra_args,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
        """
        # Pointers to the [0, 0] element of each 2D buffer slice.
        A_ptr = f"&({kernel.index(A, [0, 0])})"
        B_ptr = f"&({kernel.index(B, [0, 0])})"
        C_ptr = f"&({kernel.index(C, [0, 0])})"
        # M/N come from C's shape, K from A's second dim; leading dimensions
        # are the row strides of the respective buffers.
        M = kernel.size(C, 0)
        N = kernel.size(C, 1)
        K = kernel.size(A, 1)
        lda = kernel.stride(A, 0)
        ldb = kernel.stride(B, 0)
        ldc = kernel.stride(C, 0)
        res = IndentedBuffer()
        res.writeline(
            f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>("
        )
        with res.indent():
            kwargs_for_extra_args.update({"kernel": kernel})
            extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
            for arg in extra_args:
                res.writeline(arg)
            res.writeline(f"{A_ptr},")
            res.writeline(f"{B_ptr},")
            res.writeline(f"{C_ptr},")
            res.writeline(f"{M},")
            res.writeline(f"{N},")
            res.writeline(f"{K},")
            res.writeline(f"{lda},")
            res.writeline(f"{ldb},")
            res.writeline(f"{ldc}")
        res.writeline(");")
        return res.getvalue()

    def use_local_vnni_blocking(self, should_block_weight: bool):
        # Switch between ahead-of-time and in-kernel VNNI packing of B.
        self.pack_vnni_B_locally = should_block_weight

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Per-thread initialization code; empty by default (subclass hook).
        return ""

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Per-thread finalization code; empty by default (subclass hook).
        return ""

    def get_b_layout(self) -> LayoutType:
        # Expected layout of the B operand; subclasses return VNNI variants.
        return LayoutType.NORMAL

    # Jinja template that allocates a temporary weight buffer, on the heap
    # for MSVC and on the (page-aligned) stack elsewhere.
    ALLOCATE_WEIGHT_BUFFER = r"""
{%- if is_msvc_compiler %}
    // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here.
    std::unique_ptr<{{buffer_dtype}}[]> heap_deq_b_buf_ptr(new {{buffer_dtype}}[{{buffer_size}}]);
    {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get();
{%- else %}
    // It's safe to use a stack-allocated array since the blocking strategy would
    // require us to allocate an array that's smaller than the size of L1D cache,
    // and the default per thread max stack size on Linux is quite higher,
    // so we need not worry about stack overflow.
    alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}];
{%- endif %}
"""

    def codegen_allocate_weight_buffer(
        self, buffer_name: str, buffer_dtype: str, *size_args
    ) -> str:
        """Emit C++ that allocates a weight scratch buffer of prod(size_args) elements."""
        buffer_size = " * ".join(map(str, size_args))
        return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
            dict(
                buffer_name=buffer_name,
                buffer_dtype=buffer_dtype,
                buffer_size=buffer_size,
                is_msvc_compiler=cpp_builder.is_msvc_cl(),
            )
        )

    def is_woq_int4(self):
        # Whether this kernel is a weight-only-quantized int4 GEMM; overridden
        # by the relevant subclasses.
        return False
224
+
225
+
226
@dataclasses.dataclass
class CppMicroGemmConfig:
    """One registrable (dtypes, ISA, register-blocking) combination for a micro-GEMM class."""

    input_dtype: torch.dtype  # dtype of operand A
    input2_dtype: torch.dtype  # dtype of operand B
    output_dtype: torch.dtype  # dtype of the C output
    compute_dtype: torch.dtype  # dtype used for accumulation
    vec_isa_cls: type[VecISA]  # vector ISA this config targets (e.g. VecAVX512)
    register_blocking: GemmBlocking  # (block_m, block_n, block_k) register tile
    # Optional predicate deciding whether this config applies for a given
    # problem; called as extra_check(config, m, n, k, alpha, num_threads,
    # **kwargs) — see check_int8_woq_small_m_dim for an example.
    extra_check: Optional[Callable[..., bool]] = None
235
+
236
+
237
# Registry mapping each micro-GEMM implementation class to the list of
# configs (dtype / ISA / register-blocking combinations) it supports.
micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}


def register_micro_gemm(*configs):
    """Class decorator that records the given configs for a CppMicroGemm subclass."""

    def _decorator(klass):
        assert klass not in micro_gemm_configs, (
            f"Duplicate micro_gemm registration for {klass}"
        )
        assert len(configs) > 0, f"No micro_gemm configs provided for {klass}"
        micro_gemm_configs[klass] = list(configs)
        return klass

    return _decorator
250
+
251
+
252
def generate_gemm_config(
    vec_isa_cls,
    register_blockings,
    input_dtype=torch.float,
    input2_dtype=None,
    output_dtype=None,
    compute_dtype=None,
    extra_check=None,
):
    """Build one CppMicroGemmConfig per entry of `register_blockings`.

    Unspecified dtypes default in a cascade: output_dtype falls back to
    input_dtype, compute_dtype to output_dtype, and input2_dtype to
    input_dtype.
    """
    output_dtype = input_dtype if output_dtype is None else output_dtype
    compute_dtype = output_dtype if compute_dtype is None else compute_dtype
    input2_dtype = input_dtype if input2_dtype is None else input2_dtype
    configs = []
    for blocking in register_blockings:
        configs.append(
            CppMicroGemmConfig(
                input_dtype,
                input2_dtype,
                output_dtype,
                compute_dtype,
                vec_isa_cls,
                GemmBlocking(*blocking),
                extra_check,
            )
        )
    return configs
279
+
280
+
281
class CppMicroGemmRef(CppMicroGemm):
    """
    A reference implementation of the CppMicroGemm class with naive C++ code.
    It is used for correctness debugging.
    """

    # Naive triple loop computing C (+)= alpha * A @ B in compute_t precision.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            {{compute_t}} result = accum ? C[m * ldc + n] : 0;
            for (int64_t k = 0; k < K; ++k) {
                result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
            }
            C[m * ldc + n] = result;
        }
    }
}
"""

    def __init__(
        self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
    ) -> None:
        # The reference kernel uses a trivial 1x1x1 register blocking since it
        # performs no vectorization.
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            GemmBlocking(1, 1, 1),
            alpha,
        )

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the naive reference kernel definition."""
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            **self.get_common_options(),
        }
        return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
320
+
321
+
322
def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
    """Return True for the small-M int8 WoQ corner case.

    Holds when M < 16 and both K and N divide evenly by the config's
    register-blocking block_k / block_n.
    """
    blocking = config.register_blocking
    if k % blocking.block_k != 0:
        return False
    if n % blocking.block_n != 0:
        return False
    return m < 16
328
+
329
+
330
# extra check for small M dimension for int8 WoQ case
def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
    """Select small-M int8 WoQ configs, but only when M is static."""
    if kwargs.get("dynamic_M", False):
        return False
    return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k)
335
+
336
+
337
# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
    """Inverse gate of check_int8_woq_small_m_dim for the general blockings."""
    small_m_case = check_int8_woq_small_m_dim(
        config, m, n, k, alpha, num_threads, **kwargs
    )
    return not small_m_case
340
+
341
+
342
+ @register_micro_gemm(
343
+ *generate_gemm_config(
344
+ VecAVX512,
345
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
346
+ input_dtype=torch.float,
347
+ ),
348
+ *generate_gemm_config(
349
+ VecAVX512,
350
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
351
+ input_dtype=torch.bfloat16,
352
+ output_dtype=torch.float,
353
+ ),
354
+ *generate_gemm_config(
355
+ VecAVX512,
356
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
357
+ input_dtype=torch.half,
358
+ output_dtype=torch.float,
359
+ ),
360
+ *generate_gemm_config(
361
+ VecAVX512,
362
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
363
+ input_dtype=torch.bfloat16,
364
+ input2_dtype=torch.int8,
365
+ output_dtype=torch.float,
366
+ compute_dtype=torch.float,
367
+ extra_check=do_not_use_with_small_m_for_int8_woq,
368
+ ),
369
+ *generate_gemm_config(
370
+ VecAVX512,
371
+ [
372
+ (4, 32, 64),
373
+ (8, 32, 64),
374
+ ],
375
+ input_dtype=torch.bfloat16,
376
+ input2_dtype=torch.int8,
377
+ output_dtype=torch.float,
378
+ compute_dtype=torch.float,
379
+ extra_check=check_int8_woq_small_m_dim,
380
+ ),
381
+ *generate_gemm_config(
382
+ VecAVX2,
383
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
384
+ input_dtype=torch.float,
385
+ ),
386
+ *generate_gemm_config(
387
+ VecAVX2,
388
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
389
+ input_dtype=torch.bfloat16,
390
+ output_dtype=torch.float,
391
+ ),
392
+ *generate_gemm_config(
393
+ VecAVX2,
394
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
395
+ input_dtype=torch.half,
396
+ output_dtype=torch.float,
397
+ ),
398
+ *generate_gemm_config(
399
+ VecAVX2,
400
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
401
+ input_dtype=torch.bfloat16,
402
+ input2_dtype=torch.int8,
403
+ output_dtype=torch.float,
404
+ compute_dtype=torch.float,
405
+ extra_check=do_not_use_with_small_m_for_int8_woq,
406
+ ),
407
+ *generate_gemm_config(
408
+ VecAVX2,
409
+ [
410
+ (2, 16, 64),
411
+ (4, 16, 64),
412
+ ],
413
+ input_dtype=torch.bfloat16,
414
+ input2_dtype=torch.int8,
415
+ output_dtype=torch.float,
416
+ compute_dtype=torch.float,
417
+ extra_check=check_int8_woq_small_m_dim,
418
+ ),
419
+ *generate_gemm_config(
420
+ VecNEON,
421
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
422
+ input_dtype=torch.float,
423
+ input2_dtype=torch.float,
424
+ output_dtype=torch.float,
425
+ compute_dtype=torch.float,
426
+ ),
427
+ *generate_gemm_config(
428
+ VecSVE256,
429
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
430
+ input_dtype=torch.float,
431
+ input2_dtype=torch.float,
432
+ output_dtype=torch.float,
433
+ compute_dtype=torch.float,
434
+ ),
435
+ )
436
class CppMicroGemmFP32Vec(CppMicroGemm):
    """
    This class generates the code for micro gemm using fp32 vec instructions for compute.
    It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
    The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
    if the desired output is BF16/FP16.
    """

    # C++/Jinja entry-point template: tiles M and N into register blocks and
    # dispatches to the per-block-size kernels generated from TEMPLATE_KERNEL.
    # Variants are selected at render time via trans_b / tail_n flags.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
    constexpr auto VLEN = Vectorized::size();
    {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            int64_t block_n = std::min<int64_t>(N - n, {{block_n}});
            if (block_m == {{block_m}} && block_n == {{block_n}}) {
                {%- if not trans_b %}
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
                {%- else %}
                {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
                {%- endif %}
                    A + m * lda,
                    {%- if not trans_b %}
                    B + n,
                    {%- else %}
                    B + n * ldb,
                    {%- endif %}
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc
                );
            {%- if tail_n %}
            } else if (block_n == {{block_n}}){
            {%- else %}
            } else {
            {%- endif %}
                switch (block_m) {
                    {%- for b in range(block_m - 1, 0, -1) %}
                    case {{b}}:
                        {%- if not trans_b %}
                        {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
                        {%- else %}
                        {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
                        {%- endif %}
                            A + m * lda,
                            {%- if not trans_b %}
                            B + n,
                            {%- else %}
                            B + n * ldb,
                            {%- endif %}
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    {%- endfor %}
                    default:
                        {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
                }

            {%- if tail_n %}
            } else {
                switch (block_m) {
                    {%- for b in range(block_m, 0, -1) %}
                    case {{b}}:
                        {%- if not trans_b %}
                        {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
                        {%- else %}
                        {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
                        {%- endif %}
                            A + m * lda,
                            {%- if not trans_b %}
                            B + n,
                            {%- else %}
                            B + n * ldb,
                            {%- endif %}
                            C + m * ldc + n,
                            block_n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    {%- endfor %}
                    default:
                        {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
                }
            }
            {%- else %}
            }
            {%- endif %}
        }
    }
}
"""

    # C++/Jinja template for a single register-blocked kernel. Rendered once
    # per (trans_b, tail_n) combination in codegen_define below.
    TEMPLATE_KERNEL = r"""

template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum, bool prefetch=false>
{%- if not trans_b %}
{%- if tail_n %}
inline void {{kernel_name}}_ntail_kernel(
{%- else %}
inline void {{kernel_name}}_kernel(
{%- endif %}
{%- else %}
{%- if tail_n %}
inline void {{kernel_name}}_ntail_transpose_b_kernel(
{%- else %}
inline void {{kernel_name}}_transpose_b_kernel(
{%- endif %}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    {%- if tail_n %}
    int64_t N,
    {%- endif %}
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
    {%- endif %}

    {%- if not trans_b %}
    constexpr auto VLEN = Vectorized::size();
    constexpr auto ROWS = BLOCK_M;
    constexpr auto COLS = BLOCK_N / VLEN;

    Vectorized va;
    at::vec::VectorizedN<{{compute_t}}, COLS> vb;
    at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;

    {%- if tail_n %}
    int64_t rCOLS = (N + VLEN - 1) / VLEN;
    int ntail = N % VLEN;
    {%- endif %}
    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            {%- if tail_n %}
            int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
            if (col < rCOLS) {
                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size);
            }
            {%- else %}
            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
            {%- endif %}
        } else {
            vc[i] = Vectorized(0.0f);
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);

    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        {%- if tail_n %}
        int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
        {%- endif %}
        if constexpr (col == 0) {
            {%- if alpha != 1 %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
            {%- else %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
            {%- endif %}
        }

        if constexpr (row == 0) {
            {%- if tail_n %}
            if (col < rCOLS) {
                {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size);
                vb[col] = at::vec::convert<{{compute_t}}>(b);
                {%- elif input2_dtype == torch.int8 %}
                // Convert VLEN int8 elements to int32, and then fp32
                auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN, load_size);
                vb[col] = at::vec::convert<float>(b32);
                {%- else %}
                vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size);
                {%- endif %}
            } else {
                vb[col] = Vectorized(0.0f);
            }

            {%- else %}

            {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
            vb[col] = at::vec::convert<{{compute_t}}>(b);
            {%- elif input2_dtype == torch.int8 %}
            // Convert VLEN int8 elements to int32, and then fp32
            auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
            if constexpr (prefetch) {
                _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
            }
            vb[col] = at::vec::convert<float>(b32);
            {%- else %}
            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
            {%- endif %}
            {%- endif %}

        }

        constexpr int idx = row * COLS + col;
        {%- if tail_n %}
        if (col < rCOLS) {
            vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
        }
        {%- else %}
        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
        {%- endif %}
    };

    for (int k = 0; k < K; ++k) {
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }

    // store to C
    auto storec = [&](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        {%- if tail_n %}
        int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
        if (col < rCOLS) {
            vc[i].store(C + row * ldc + col * VLEN, store_size);
        }
        {%- else %}
        vc[i].store(C + row * ldc + col * VLEN);
        {%- endif %}
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);

    {%- else %}
    // Use 2 implementations for the transposed B:
    // First implementation:
    // Transpose first and then perform outer product calculation in sub-blocks,
    // which introduces an additional transpose overhead of [K, N] compared to the non-transpose version.
    // Second implementation:
    // Directly perform inner product calculation in sub-blocks,
    // which introduces an additional vector reduction of [M, N] compared to the non-tranpose version.
    // Therefore, when M * N / (K * N) is large, the first implementation has better performance.
    {%- if tail_n %}
    if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
    {%- else %}
    if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
    {%- endif %}
        // First implementation:
        constexpr auto VLEN = Vectorized::size();
        constexpr auto ROWS = BLOCK_M;
        constexpr auto COLS = BLOCK_N / VLEN;
        int _K = K / VLEN;
        Vectorized va;
        at::vec::VectorizedN<{{compute_t}}, VLEN> vb;
        at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
        auto loadc = [&](auto i) {
            if constexpr (accum) {
                constexpr int row = i / COLS;
                constexpr int col = i % COLS;
                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
            } else {
                vc[i] = Vectorized(0.0f);
            }
        };
        c10::ForcedUnroll<ROWS * COLS>{}(loadc);
        auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) {
            {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN);
            vb[i] = at::vec::convert<{{compute_t}}>(b);
            {%- elif input2_dtype == torch.int8 %}
            auto b32 = at::vec::convert_to_int32<int8_t>(src_ptr + i * ldb, VLEN);
            vb[i] = at::vec::convert<float>(b32);
            {%- else %}
            vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN);
            {%- endif %}
        };
        auto compute_trans = [&, COLS](auto i, int k) {
            constexpr int row = i % ROWS;
            constexpr int col = i / ROWS;
            constexpr int e_col = col * VLEN;
            int idk = k * VLEN;
            if constexpr (row == 0) {
                c10::ForcedUnroll<VLEN>{}(unroll_loadB, B + e_col * ldb + idk);
                at::vec::transpose_block(vb);
            }
            constexpr int idx = row * COLS + col;
            {{kernel.unroll_pragma(16)}}
            for (int j = 0; j < VLEN; j++) {
                {%- if alpha != 1 %}
                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}});
                {%- else %}
                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]));
                {%- endif %}
                vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]);
            }
        };
        for (int k = 0; k < _K; ++k) {
            c10::ForcedUnroll<ROWS * COLS>{}(compute_trans, k);
        }
        // store to C
        auto storec = [&](auto i) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i].store(C + row * ldc + col * VLEN);
        };
        c10::ForcedUnroll<ROWS * COLS>{}(storec);
    } else {
        // Second implementation
        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
        constexpr auto VLEN = VectorizedIn::size();
        {%- else %}
        constexpr auto VLEN = Vectorized::size();
        {%- endif %}
        int _K = (K + VLEN - 1) / VLEN;
        // sub-block size of BLOCK_N and BLOCK_M
        constexpr int sM = {{sub_block_m}};
        constexpr int sN = {{sub_block_n}};
        {%- if tail_n %}
        int bN = (N + sN - 1) / sN;
        {%- else %}
        constexpr int bN = (BLOCK_N + sN - 1) / sN;
        {%- endif %}
        constexpr int bM = (BLOCK_M + sM - 1) / sM;

        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
        at::vec::VectorizedN<{{compute_t}}, 2> va;
        at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb;
        {%- else %}
        at::vec::Vectorized<{{compute_t}}> va;
        at::vec::VectorizedN<{{compute_t}}, sN> vb;
        {%- endif %}
        at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid;

        {%- if tail_n %}
        int ntail = N % sN;
        {%- else %}
        constexpr int ntail = BLOCK_N % sN;
        {%- endif %}
        constexpr int mtail = BLOCK_M % sM;
        int ktail = K % VLEN;

        auto compute_trans = [&](int m, int n, int k) {
            {%- if tail_n %}
            int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN;
            {%- else %}
            int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN;
            {%- endif %}
            int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM;
            int e_k = (k == _K - 1 && ktail != 0) ? (K - k * VLEN) : VLEN;
            {{kernel.unroll_pragma(sub_block_n)}}
            for (int i = 0; i < e_n; i++) {
                {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
                std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b);
                {%- elif input2_dtype == torch.int8 %}
                auto b32 = at::vec::convert_to_int32<int8_t>(B + (sN * n + i) * ldb + k * VLEN, e_k);
                vb[i] = at::vec::convert<float>(b32);
                {%- else %}
                vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
                {%- endif %}
            }

            {{kernel.unroll_pragma(sub_block_m)}}
            for (int s = 0; s < e_m; s++) {
                {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
                std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a);
                {%- elif input2_dtype == torch.int8 %}
                auto a32 = at::vec::convert_to_int32<int8_t>(A + (sM * m + s) * lda + k * VLEN, e_k);
                va = at::vec::convert<float>(a32);
                {%- else %}
                va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
                {%- endif %}

                {%- if alpha != 1 %}
                va = va * Vectorized({{alpha}});
                {%- endif %}
                if (k == 0) {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
                        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f));
                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
                        {%- else %}
                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f));
                        {%- endif %}
                    }
                } else {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
                        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]);
                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
                        {%- else %}
                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]);
                        {%- endif %}
                    }
                }
            }

            // store to C
            if (k == _K - 1) {
                {{kernel.unroll_pragma(sub_block_m)}}
                for (int s = 0; s < e_m; s++) {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
                        auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]);
                        if constexpr (accum) {
                            auto c = *(C + (sM * m + s) * ldc + sN * n + i);
                            *(C + (sM * m + s) * ldc + sN * n + i) = c + v;
                        } else {
                            *(C + (sM * m + s) * ldc + sN * n + i) = v;
                        }
                    }
                }
            }
        };

        for (int n = 0; n < bN; ++n) {
            for (int m = 0; m < bM; ++m) {
                for (int k = 0; k < _K; ++k) {
                    compute_trans(m, n, k);
                }
            }
        }
    }
    {%- endif %}
}
"""

    # set trans_b to generate gemm that supports transposed B matrix
    # set tail_n to support the tail of N
    # TODO add trans_b support for other micro gemms
    # and move setting of trans_b to the init of CppMicroGemm
    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
        tail_n=False,
        trans_b=False,
    ) -> None:
        """
        Initialize the FP32-vec micro-gemm.

        Extra args beyond the base class:
            tail_n: also generate kernels that handle a partial (tail) N block.
            trans_b: generate kernels for a transposed B operand.
        """
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            register_blocking,
            alpha,
        )
        self.tail_n = tail_n
        # trans_b is only supported on platforms that
        # support avx512 or avx2 since transpose_block is
        # only implemented on these platforms
        if trans_b:
            vec_isa = pick_vec_isa()
            assert issubclass(vec_isa.__class__, VecAVX512) or issubclass(
                vec_isa.__class__, VecAVX2
            )
        self.trans_b = trans_b

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """
        Render the C++ source for this micro-gemm: the register-blocked
        kernel(s) (plus an N-tail variant when self.tail_n) followed by the
        entry-point function, and return it as one string.
        """
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "trans_b": False,
            "tail_n": False,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        if self.trans_b:
            # TODO supports tuning of sub_block_m/sub_block_n
            # to get better performance for specific shapes
            sub_block_m = min(1, self.register_blocking.block_m)
            sub_block_n = min(4, self.register_blocking.block_n)
            # update options to generate kernel with trans_b and sub-block size
            options.update(
                {
                    "trans_b": self.trans_b,
                    "sub_block_m": sub_block_m,
                    "sub_block_n": sub_block_n,
                }
            )
        result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
            options
        )
        # update options to generate the kernel for the tail of N
        if self.tail_n:
            options.update(
                {
                    "tail_n": self.tail_n,
                }
            )
            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
                options
            )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result
958
+
959
+
960
+ # extra check for CppMicroGemmAMX
961
def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra eligibility check for CppMicroGemmAMX: K must be a multiple of
    the VNNI packing size (4 for int8/uint8 input, 2 otherwise) and no alpha
    scaling may be requested."""
    if alpha != 1:
        return False
    vnni_size = 4 if config.input_dtype in (torch.uint8, torch.int8) else 2
    return k % vnni_size == 0
964
+
965
+
966
@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.int8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.uint8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
)
class CppMicroGemmAMX(CppMicroGemm):
    """
    This class generates the code for micro gemm using Advanced Matrix extension (AMX)
    instructions available in 4th generation Intel Xeon for compute.
    It supports input types of torch.bfloat16 with fp32 output.
    """

    # C++/Jinja entry-point template: optionally dequantizes / VNNI-packs B,
    # then tiles M into 16-row AMX slabs and dispatches to the generated
    # {{kernel_name}}_amx_kernel_<rows>_<cols> kernels.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
    {%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}}
    {%- endif %}
    {%- if use_cached_dequantized_B %}
    // Create a stack-allocated buffer for tiles of B.
    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
    // we cache K * {{block_n}} elements of dequantized B
    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
    const auto buf_size = K * {{block_n}};
    auto load_dequantized_B = [&](int base_idx) {
        // Load a tile of B & cache it in L1D.
        {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
        for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
            {%- for vec_idx in range(0, block_n, 32) %}
            {%- if (block_n - vec_idx) >= 32 %}
            auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized<int8_t>::loadu(
                base_addr + idx_q + {{vec_idx}} ,
                static_cast<int64_t>(32)
            );
            auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}});
            b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}});
            {%- else %}
            auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
                base_addr + idx_q + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
            auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail);
            b_bf16_tail.store(
                dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
            {%- endif %}
            {%- endfor %}
        }
    };
    {%- endif %}
    // The ldb would not be block_n if N != block_n
    {%- if use_cached_dequantized_B or pack_vnni_B_locally %}
    const int64_t updated_ldb = {{block_n}};
    {%- else %}
    const int64_t updated_ldb = ldb;
    {%- endif %}
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += {{block_n}}) {
        {%- if pack_vnni_B_locally %}
        // Pack non-constant weights into VNNI interleaved format in packed_B_buf
        at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}});
        {%- elif use_cached_dequantized_B %}
        // Dequantize K * block_n int8 B elements into BF16
        load_dequantized_B(n);
        {%- endif %}
        for (int64_t m = 0; m < M; m += {{block_m}}) {
            int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
            int64_t m_tail = m;
            {%- for num_rows in range(block_m, 0, -16) %}
            {%- if num_rows != block_m %}
            else
            {%- endif %}
            if (block_m >= {{num_rows}}) {
                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
                    amx_state,
                    A + m * lda,
                    {%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
                    {%- elif pack_vnni_B_locally %}
                    packed_B_buf,
                    {%- else %}
                    B + n,
                    {%- endif %}
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= {{num_rows}};
                m_tail += {{num_rows}};
            }
            {%- endfor %}
            if (block_m > 0) {
                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
                    amx_state,
                    A + m_tail * lda,
                    {%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
                    {%- elif pack_vnni_B_locally %}
                    packed_B_buf,
                    {%- else %}
                    B + n,
                    {%- endif %}
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
"""

    # C++/Jinja template for one AMX tile kernel; rendered once per number of
    # rows (multiples of 16) in codegen_define below.
    TEMPLATE_KERNEL = r"""

template <bool accum, bool prefetch=false>
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
    AMXState& amx_state,
    const {{input_t}}* {{restrict_keyword}} A,
    {%- if use_cached_dequantized_B %}
    const {{input_t}}* {{restrict_keyword}} B,
    {%- else %}
    const {{input2_t}}* {{restrict_keyword}} B,
    {%- endif %}
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / {{block_k}} * {{block_k}};
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
    }
    auto load_c = [&]() {
        {%- for tile_row in range(num_rows // 16) %}
        {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
        {%- endfor %}
        {%- endfor %}
    };
    auto zero_c = [&]() {
        {%- for tile_row in range(num_rows // 16) %}
        {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_zero({{tile_idx}});
        {%- endfor %}
        {%- endfor %}
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        {%- set tile_offset_a = num_rows // 16 * num_columns %}
        {%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
        {%- for tile_row in range(num_rows // 16) %}
        {%- for tile_col in range(num_columns) %}
        {%- set tile_idx_a = tile_offset_a + tile_row %}
        {%- set tile_idx_b = tile_offset_b + tile_col %}
        {%- set tile_idx_c = tile_row * num_columns + tile_col %}
        {%- if tile_col == 0 %}
        _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
        {%- endif %}
        {%- if tile_row == 0 %}
        _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
        {%- endif %}
        {%- if int8_gemm %}
        {%- if input_dtype == torch.int8 %}
        _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
        {%- else %}
        _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
        {%- endif %}
        {%- else %}
        _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
        {%- endif %}
        {%- endfor %}
        {%- endfor %}
    };

    {{kernel.unroll_pragma(4)}}
    for (int k = 0; k < last_k_offset; k += {{block_k}}) {
        compute(k);
    }

    auto store_c = [&]() {
        // store to C
        {%- for tile_row in range(num_rows // 16) %}
        {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
        {%- endfor %}
        {%- endfor %}
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """
        Render the C++ source for the AMX micro-gemm: one tile kernel per
        16-row slab of block_m, followed by the entry-point function.
        """
        block_m, block_n, block_k = self.register_blocking
        assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
        assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
        if self.input_dtype in [torch.uint8, torch.int8]:
            assert block_k == 64, "Only support block_k = 64 for AMX INT8"
        else:
            assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
        num_columns = block_n // 16
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            # bf16 activation with int8 weight: dequantize B once per n-block
            # and reuse the cached BF16 copy in the tile kernels
            "use_cached_dequantized_B": (
                self.input_dtype == torch.bfloat16
                and self.input2_dtype in [torch.int8, torch.uint8]
            ),
            "kernel": kernel,
            "block_m": block_m,
            "block_n": block_n,
            "block_k": block_k,
            "num_columns": num_columns,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        for num_rows in range(block_m, 0, -16):
            amx_kernel_options = {**options, "num_rows": num_rows}
            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
                amx_kernel_options
            )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Declare the AMX tile-configuration state used by the kernels.
        return "AMXState amx_state;"

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Release the AMX tile state when the kernel is done.
        return "amx_state.release([]() { _tile_release(); });"

    def get_kernel_extra_args_declare(self) -> str:
        # Extra parameter threaded through the generated kernel signature.
        return "AMXState& amx_state,"

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        # Matching argument for get_kernel_extra_args_declare above.
        return ["amx_state,"]

    def get_b_layout(self):
        # AMX consumes B in VNNI-packed layout: 4-wide for int8, 2-wide otherwise.
        if self.input_dtype in [torch.uint8, torch.int8]:
            return LayoutType.VNNI4
        else:
            return LayoutType.VNNI2
1282
+
1283
+
1284
+ # extra check for CppMicroBrgemm
1285
+ def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs):
1286
+ assert config.input_dtype == torch.half and config.output_dtype == torch.float
1287
+ vnni_size = 2
1288
+ # use brgemm for Half when amx_fp16 is supported
1289
+ return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
1290
+
1291
+
1292
@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.half,
        output_dtype=torch.float,
        extra_check=check_brgemm_extra,
    ),
)
class CppMicroBrgemm(CppMicroGemm):
    """
    This class generates the code for micro gemm using oneDNN brgemm.
    It supports input types of torch.half.
    """

    # C++/Jinja template: a thin wrapper that optionally VNNI-packs B locally
    # and delegates the whole block to at::native::cpublas::brgemm.
    TEMPLATE_ENTRY = r"""
#include <ATen/native/CPUBlas.h>
{{declare_kernel}} {
    {%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}}
    at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N);
    {%- endif %}
    at::native::cpublas::brgemm(
      M, N, K,
      {%- if pack_vnni_B_locally %}
      lda, N, ldc,
      {%- else %}
      lda, ldb, ldc,
      {%- endif %}
      accum,
      A,
      {%- if pack_vnni_B_locally %}
      packed_B_buf,
      {%- else %}
      B,
      {%- endif %}
      C);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the brgemm wrapper source for this micro-gemm."""
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Release oneDNN brgemm resources after the kernel finishes.
        return "at::native::cpublas::brgemm_release();"

    def get_b_layout(self):
        # brgemm is only registered for Half on AMX-FP16 machines and expects
        # B in VNNI2 layout.
        assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported()
        return LayoutType.VNNI2
1357
+
1358
+
1359
def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra eligibility check for the WoQ int4 micro-gemm: requires
    alpha == 1, a quantization group size >= 32 that divides K and is at
    least block_k, K a multiple of block_k, and N a multiple of 64."""
    if alpha != 1:
        return False
    q_group_size = kwargs.get("q_group_size", None)
    assert q_group_size is not None
    if q_group_size < 32 or k % q_group_size != 0:
        return False
    block_k = config.register_blocking.block_k
    if block_k > q_group_size:
        return False
    return k % block_k == 0 and n % 64 == 0
1371
+
1372
+
1373
+ @register_micro_gemm(
1374
+ # TODO: support float/half input
1375
+ *generate_gemm_config(
1376
+ VecAVX512,
1377
+ [(4, 64, 32), (4, 64, 64), (4, 64, 128)],
1378
+ input_dtype=torch.bfloat16,
1379
+ input2_dtype=torch.uint8,
1380
+ output_dtype=torch.float,
1381
+ compute_dtype=torch.float,
1382
+ extra_check=check_woq_int4_extra,
1383
+ ),
1384
+ )
1385
+ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec):
1386
+ """
1387
+ This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
1388
+ It is based on the corresponding ATen kernel.
1389
+ Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
1390
+ Shape of packed ScalesAndZeros = [K // group_size, N, 2]
1391
+ """
1392
+
1393
+ TEMPLATE_ENTRY = r"""
1394
+ {{declare_kernel}} {
1395
+ {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
1396
+ {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
1397
+ auto group_size = q_group_size;
1398
+ for (int64_t m = 0; m < M; m += {{block_m}}) {
1399
+ int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
1400
+ for (int64_t n = 0; n < N; n += {{block_n}}) {
1401
+ if (block_m == {{block_m}}) {
1402
+ {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
1403
+ A + m * lda,
1404
+ reinterpret_cast<const uint8_t*>(B) + n * ldb,
1405
+ C + m * ldc + n,
1406
+ K,
1407
+ lda,
1408
+ /* ldb */ {{block_n}} / 2,
1409
+ ldc,
1410
+ group_size,
1411
+ ScaleAndZeros + n * 2,
1412
+ lds,
1413
+ k_start
1414
+ );
1415
+ } else {
1416
+ switch (block_m) {
1417
+ {%- for b in range(block_m - 1, 0, -1) %}
1418
+ case {{b}}:
1419
+ {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
1420
+ A + m * lda,
1421
+ reinterpret_cast<const uint8_t*>(B) + n * ldb,
1422
+ C + m * ldc + n,
1423
+ K,
1424
+ lda,
1425
+ /* ldb */ {{block_n}} / 2,
1426
+ ldc,
1427
+ group_size,
1428
+ ScaleAndZeros + n * 2,
1429
+ lds,
1430
+ k_start
1431
+ );
1432
+ break;
1433
+ {%- endfor %}
1434
+ default:
1435
+ {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
1436
+ }
1437
+ }
1438
+ }
1439
+ }
1440
+ }
1441
+ """
1442
+
1443
+ TEMPLATE_KERNEL = r"""
1444
+ inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
1445
+ return (k_start + index) % group_size == 0;
1446
+ }
1447
+
1448
+ inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) {
1449
+ __m128i tmp = _mm_loadu_si64((const __m128i*)data);
1450
+ __m128i bytes = _mm_cvtepu8_epi16(tmp);
1451
+ const __m128i lowMask = _mm_set1_epi8(0xF);
1452
+ __m128i high = _mm_andnot_si128(lowMask, bytes);
1453
+ __m128i low = _mm_and_si128(lowMask, bytes);
1454
+ high = _mm_slli_epi16(high, 4);
1455
+ bytes = _mm_or_si128(low, high);
1456
+ return bytes;
1457
+ }
1458
+
1459
+ template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
1460
+ inline void {{kernel_name}}_kernel(
1461
+ const {{input_t}}* {{restrict_keyword}} A,
1462
+ const uint8_t* {{restrict_keyword}} B,
1463
+ {{output_t}}* {{restrict_keyword}} C,
1464
+ int64_t K,
1465
+ int64_t lda,
1466
+ int64_t ldb,
1467
+ int64_t ldc,
1468
+ int64_t q_group_size,
1469
+ const at::BFloat16* {{restrict_keyword}} ScaleAndZeros,
1470
+ int64_t lds, // leading dimension of ScaleAndZeros
1471
+ int64_t k_start) {
1472
+ constexpr int BLOCK_K = {{block_k}};
1473
+ constexpr int ROWS = BLOCK_M;
1474
+ constexpr int COLS = BLOCK_N / 16;
1475
+
1476
+ const int PREFETCH_SIZE_K = 16 * 4;
1477
+ const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
1478
+
1479
+ // number of blocks on K
1480
+ const int KB = K / BLOCK_K;
1481
+
1482
+ __m512 va;
1483
+ __m512 vb[COLS];
1484
+ __m512 vc[ROWS * COLS];
1485
+ __m512 scale[COLS];
1486
+ __m512 zero[COLS];
1487
+
1488
+ // Lookup table to de-quantize int4 values to bf16.
1489
+ // Values are dequantized as truly int4 [-8, 7] range;
1490
+ //
1491
+ // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
1492
+ //
1493
+ static const __m512 lut = _mm512_set_ps(
1494
+ 7.0f, 6.0f, 5.0f, 4.0f,
1495
+ 3.0f, 2.0f, 1.0f, 0.0f,
1496
+ -1.0f, -2.0f, -3.0f, -4.0f,
1497
+ -5.0f, -6.0f, -7.0f, -8.0f);
1498
+
1499
+ // index for transpose
1500
+ static const __m512i idx1 = _mm512_set_epi32(
1501
+ 30, 28, 26, 24, 22, 20, 18, 16,
1502
+ 14, 12, 10, 8, 6, 4, 2, 0);
1503
+ static const __m512i idx2 = _mm512_set_epi32(
1504
+ 31, 29, 27, 25, 23, 21, 19, 17,
1505
+ 15, 13, 11, 9, 7, 5, 3, 1);
1506
+
1507
+ // load scale and zero point
1508
+ auto load_scale_and_zeros = [&](int i, int _kb) {
1509
+ // load 2x bfloat16 vector
1510
+ __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
1511
+ if (_kb + PREFETCH_SIZE_KB < KB) {
1512
+ _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
1513
+ }
1514
+
1515
+ // convert to 2x f32 vector
1516
+ __m512 a, b;
1517
+ at::vec::cvtbf16_fp32(t, a, b);
1518
+
1519
+ // transpose scale_and_zero from {16, 2} to {2, 16}
1520
+ // inputs:
1521
+ // a: {s0, z0, s1, z1, ..., s7, z7}
1522
+ // b: {s8, z8, s9, z9, ..., s15, z15}
1523
+ // output:
1524
+ // scale: {s0, s1, s2, ..., s15}
1525
+ // zero: {z0, z1, z2, ..., z15}
1526
+ scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
1527
+ zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
1528
+ };
1529
+
1530
+ auto loadc = [&](auto i) {
1531
+ if constexpr (accum) {
1532
+ constexpr int row = i / COLS;
1533
+ constexpr int col = i % COLS;
1534
+ vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
1535
+ } else {
1536
+ vc[i] = _mm512_setzero_ps();
1537
+ }
1538
+ };
1539
+ c10::ForcedUnroll<ROWS * COLS>{}(loadc);
1540
+
1541
+ auto compute = [&, COLS](auto i, int k) {
1542
+ constexpr int row = i / COLS;
1543
+ constexpr int col = i % COLS;
1544
+
1545
+ if constexpr (col == 0) {
1546
+ float aa = static_cast<float>(A[row * lda + k]);
1547
+ if (k + PREFETCH_SIZE_K < K) {
1548
+ _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
1549
+ }
1550
+ va = _mm512_set1_ps(aa);
1551
+ }
1552
+
1553
+ if constexpr (row == 0) {
1554
+ if constexpr (COLS == 4) {
1555
+ // when BLOCK_N = 64, handle each row at a time
1556
+ // to reduce de-quantize overhead.
1557
+ if constexpr (col == 0) {
1558
+ __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
1559
+ if (k + PREFETCH_SIZE_K < K) {
1560
+ _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
1561
+ }
1562
+
1563
+ __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
1564
+ vb[0] = _mm512_permutexvar_ps(b32, lut);
1565
+ vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
1566
+ vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1567
+ vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
1568
+
1569
+ b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
1570
+ vb[1] = _mm512_permutexvar_ps(b32, lut);
1571
+ vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
1572
+ vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1573
+ vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
1574
+ }
1575
+ } else {
1576
+ __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8);
1577
+ __m512i b32 = _mm512_cvtepu8_epi32(b8);
1578
+ vb[col] = _mm512_permutexvar_ps(b32, lut);
1579
+ vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
1580
+ }
1581
+ }
1582
+
1583
+ constexpr int idx = row * COLS + col;
1584
+ vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
1585
+ };
1586
+
1587
+ for (int k = 0, kb = 0; k < K; ++k) {
1588
+ if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) {
1589
+ c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
1590
+ }
1591
+ c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
1592
+ }
1593
+
1594
+ //store to C
1595
+ auto storec = [&, COLS](auto i) {
1596
+ constexpr int row = i / COLS;
1597
+ constexpr int col = i % COLS;
1598
+ _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
1599
+ };
1600
+ c10::ForcedUnroll<ROWS * COLS>{}(storec);
1601
+ }
1602
+ """
1603
+
1604
+ def get_kernel_extra_args_declare(self) -> str:
1605
+ return (
1606
+ "const int64_t q_group_size,\n"
1607
+ " const at::BFloat16* __restrict__ ScaleAndZeros,\n"
1608
+ " const int64_t lds,\n"
1609
+ " int64_t k_start,"
1610
+ )
1611
+
1612
+ def get_kernel_extra_args(self, **kwargs) -> list[str]:
1613
+ assert "kernel" in kwargs
1614
+ assert "qscale_and_zeros" in kwargs
1615
+ kernel = kwargs["kernel"]
1616
+ qscale_and_zeros = kwargs["qscale_and_zeros"]
1617
+ return [
1618
+ "group_size,",
1619
+ f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
1620
+ "N * 2,", # lds
1621
+ "k_start,",
1622
+ ]
1623
+
1624
+ def is_woq_int4(self):
1625
+ return True
1626
+
1627
+
1628
+ @register_micro_gemm(
1629
+ *generate_gemm_config(
1630
+ VecAMX,
1631
+ [ # (block_m, block_n, block_k)
1632
+ (16, 32, 32),
1633
+ (32, 32, 32),
1634
+ ],
1635
+ input_dtype=torch.bfloat16,
1636
+ input2_dtype=torch.uint8,
1637
+ output_dtype=torch.float,
1638
+ compute_dtype=torch.float,
1639
+ extra_check=check_amx_extra,
1640
+ ),
1641
+ )
1642
+ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1643
+ """
1644
+ This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
1645
+ which are available on 4th and newer generations of Intel Xeon.
1646
+ Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
1647
+ Shape of packed ScalesAndZeros = [K // group_size, N, 2]
1648
+ Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.
1649
+ """
1650
+
1651
+ TEMPLATE_ENTRY = r"""
1652
+ inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
1653
+ return (k_start + index) % group_size == 0;
1654
+ }
1655
+
1656
+ {{declare_kernel}} {
1657
+ {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
1658
+ {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
1659
+ {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4");
1660
+
1661
+ // Create a stack-allocated buffer for tiles of B.
1662
+ // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
1663
+ // we cache K * {{block_n}} elements of dequantized B
1664
+ {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
1665
+
1666
+ constexpr int BLOCK_K = {{block_k}};
1667
+ constexpr int64_t BLOCK_N = {{block_n}};
1668
+ constexpr int COLS = BLOCK_N / 16;
1669
+ const int PREFETCH_SIZE_K = 16 * 4;
1670
+ const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
1671
+ const int KB = K / BLOCK_K;
1672
+
1673
+ __m512i b32[COLS * 2];
1674
+ __m512 vb[COLS * 2];
1675
+ __m512 scale[COLS];
1676
+ __m512 zero[COLS];
1677
+
1678
+ // Lookup table to de-quantize int4 values to bf16.
1679
+ // Values are dequantized as truly int4 [-8, 7] range;
1680
+ //
1681
+ // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
1682
+ //
1683
+ static const __m512 lut = _mm512_set_ps(
1684
+ 7.0f, 6.0f, 5.0f, 4.0f,
1685
+ 3.0f, 2.0f, 1.0f, 0.0f,
1686
+ -1.0f, -2.0f, -3.0f, -4.0f,
1687
+ -5.0f, -6.0f, -7.0f, -8.0f);
1688
+
1689
+ // index for transpose
1690
+ static const __m512i idx1 = _mm512_set_epi32(
1691
+ 30, 28, 26, 24, 22, 20, 18, 16,
1692
+ 14, 12, 10, 8, 6, 4, 2, 0);
1693
+ static const __m512i idx2 = _mm512_set_epi32(
1694
+ 31, 29, 27, 25, 23, 21, 19, 17,
1695
+ 15, 13, 11, 9, 7, 5, 3, 1);
1696
+
1697
+ // Indices for VNNI layout conversion
1698
+ __m512i idx_low = _mm512_set_epi32(
1699
+ 0x17,
1700
+ 0x07,
1701
+ 0x16,
1702
+ 0x06,
1703
+ 0x15,
1704
+ 0x05,
1705
+ 0x14,
1706
+ 0x04,
1707
+ 0x13,
1708
+ 0x03,
1709
+ 0x12,
1710
+ 0x02,
1711
+ 0x11,
1712
+ 0x01,
1713
+ 0x10,
1714
+ 0x00);
1715
+ __m512i idx_high = _mm512_set_epi32(
1716
+ 0x1f,
1717
+ 0x0f,
1718
+ 0x1e,
1719
+ 0x0e,
1720
+ 0x1d,
1721
+ 0x0d,
1722
+ 0x1c,
1723
+ 0x0c,
1724
+ 0x1b,
1725
+ 0x0b,
1726
+ 0x1a,
1727
+ 0x0a,
1728
+ 0x19,
1729
+ 0x09,
1730
+ 0x18,
1731
+ 0x08);
1732
+
1733
+ // load scale and zero point
1734
+ auto load_scale_and_zeros = [&](int i, int _kb) {
1735
+ // load 2x bfloat16 vector
1736
+ __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
1737
+ if (_kb + PREFETCH_SIZE_KB < KB) {
1738
+ _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
1739
+ }
1740
+
1741
+ // convert to 2x f32 vector
1742
+ __m512 a, b;
1743
+ at::vec::cvtbf16_fp32(t, a, b);
1744
+
1745
+ // transpose scale_and_zero from {16, 2} to {2, 16}
1746
+ // inputs:
1747
+ // a: {s0, z0, s1, z1, ..., s7, z7}
1748
+ // b: {s8, z8, s9, z9, ..., s15, z15}
1749
+ // output:
1750
+ // scale: {s0, s1, s2, ..., s15}
1751
+ // zero: {z0, z1, z2, ..., z15}
1752
+ scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
1753
+ zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
1754
+ };
1755
+
1756
+ // Dequantize a B block of 2 * block_n into bf16
1757
+ // So, it handles k and k+1 at the same time
1758
+ auto dequantize_B = [&](int n) {
1759
+ constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
1760
+ for (int k = 0, kb = 0; k < K; k += 2) {
1761
+ // Since block_k must be 32 for AMX microkernels, k_start may not be
1762
+ // a multiple of q_group_size. In that case, we need to load scales
1763
+ // and zero points immediately when k == 0 here
1764
+ if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) {
1765
+ c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
1766
+ }
1767
+
1768
+ // load 256 bits = 64 elements in int4
1769
+ if (k + PREFETCH_SIZE_K < K) {
1770
+ _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);
1771
+ }
1772
+
1773
+ __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4));
1774
+ b32[0] = _mm512_cvtepu8_epi32(b4);
1775
+ b32[1] = _mm512_srli_epi32(b32[0], 4);
1776
+ vb[0] = _mm512_permutexvar_ps(b32[0] , lut);
1777
+ vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
1778
+ vb[1] = _mm512_permutexvar_ps(b32[1], lut);
1779
+ vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
1780
+
1781
+ b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4));
1782
+ b32[0 + COLS] = _mm512_cvtepu8_epi32(b4);
1783
+ b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
1784
+ vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut);
1785
+ vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
1786
+ vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
1787
+ vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);
1788
+
1789
+ for (int i = 0; i < COLS; i++) {
1790
+ // convert to VNNI
1791
+ auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]);
1792
+ auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]);
1793
+ // convert lower 16 float32 values to bfloat16
1794
+ auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low));
1795
+ // convert higher 16 float32 values to bfloat16
1796
+ auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high));
1797
+ // combine the lower 16 and higher 16 bfloat16 values
1798
+ auto v = _mm512_castsi256_si512(v0_bf16);
1799
+ v = _mm512_inserti64x4(v, v1_bf16, 1);
1800
+ // store the VNNI format bfloat16 values
1801
+ {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
1802
+ _mm512_storeu_si512(addr, v);
1803
+ }
1804
+ }
1805
+ };
1806
+
1807
+ for (int64_t n = 0; n < N; n += {{block_n}}) {
1808
+ // Dequantize K * block_n int8 B elements into BF16
1809
+ dequantize_B(n);
1810
+ for (int64_t m = 0; m < M; m += {{block_m}}) {
1811
+ int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
1812
+ int64_t m_tail = m;
1813
+ {%- for num_rows in range(block_m, 0, -16) %}
1814
+ {%- if num_rows != block_m %}
1815
+ else
1816
+ {%- endif %}
1817
+ if (block_m >= {{num_rows}}) {
1818
+ {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
1819
+ amx_state,
1820
+ A + m * lda,
1821
+ dequantized_B_buf + n * K,
1822
+ C + m * ldc + n,
1823
+ K,
1824
+ lda,
1825
+ {{block_n}},
1826
+ ldc,
1827
+ 16
1828
+ );
1829
+ block_m -= {{num_rows}};
1830
+ m_tail += {{num_rows}};
1831
+ }
1832
+ {%- endfor %}
1833
+ if (block_m > 0) {
1834
+ {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
1835
+ amx_state,
1836
+ A + m_tail * lda,
1837
+ dequantized_B_buf + n * K,
1838
+ C + m_tail * ldc + n,
1839
+ K,
1840
+ lda,
1841
+ {{block_n}},
1842
+ ldc,
1843
+ block_m
1844
+ );
1845
+ }
1846
+ } // for m
1847
+ } // for n
1848
+ }
1849
+ """
1850
+
1851
+ def get_kernel_extra_args_declare(self) -> str:
1852
+ return (
1853
+ "AMXState& amx_state,\n"
1854
+ " const int64_t q_group_size,\n"
1855
+ " const c10::BFloat16* __restrict__ ScaleAndZeros,\n"
1856
+ " const int64_t lds,\n"
1857
+ " int64_t k_start,"
1858
+ )
1859
+
1860
+ def get_kernel_extra_args(self, **kwargs) -> list[str]:
1861
+ assert "kernel" in kwargs
1862
+ assert "qscale_and_zeros" in kwargs
1863
+ kernel = kwargs["kernel"]
1864
+ qscale_and_zeros = kwargs["qscale_and_zeros"]
1865
+ return [
1866
+ "amx_state,",
1867
+ "group_size,",
1868
+ f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
1869
+ "N * 2,", # lds
1870
+ "k_start,",
1871
+ ]
1872
+
1873
+ def is_woq_int4(self):
1874
+ return True
1875
+
1876
+
1877
+ def create_micro_gemm(
1878
+ name,
1879
+ m,
1880
+ n,
1881
+ k,
1882
+ input_dtype,
1883
+ input2_dtype,
1884
+ output_dtype=None,
1885
+ compute_dtype=None,
1886
+ alpha=1,
1887
+ num_threads=-1,
1888
+ use_ref=True,
1889
+ q_group_size=None,
1890
+ ) -> Optional[CppMicroGemm]:
1891
+ """
1892
+ Based on the provided info, try to find the config of the micro-kernel that would
1893
+ deliver the best performance in terms of lower latency for this case.
1894
+ """
1895
+
1896
+ def create_from_config(cls, config: CppMicroGemmConfig):
1897
+ return cls(
1898
+ name,
1899
+ config.input_dtype,
1900
+ config.input2_dtype,
1901
+ config.output_dtype,
1902
+ config.compute_dtype,
1903
+ config.register_blocking,
1904
+ alpha,
1905
+ )
1906
+
1907
+ def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls):
1908
+ # For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
1909
+ # Exception: for dynamic shapes, we consider using the AMX micro-kernel.
1910
+ if (
1911
+ dynamic_M
1912
+ or input_dtype != torch.bfloat16
1913
+ or input2_dtype not in [torch.int8, torch.uint8]
1914
+ ):
1915
+ return False
1916
+ # For WOQ INT8, use AMX for m >= block_m
1917
+ # For WOQ INT4, use AMX for m >= 5
1918
+ block_m, *_ = config.register_blocking
1919
+ is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx
1920
+ m_threshold = 5 if is_woq_int4 else block_m
1921
+ return m < m_threshold
1922
+
1923
+ assert isinstance(n, int) or n.is_number, n
1924
+ assert isinstance(k, int) or k.is_number, k
1925
+ from ..utils import has_free_symbols
1926
+
1927
+ dynamic_M = has_free_symbols((m,))
1928
+ m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
1929
+ assert isinstance(m, int) or m.is_number, m
1930
+ if output_dtype is None:
1931
+ output_dtype = input_dtype
1932
+ if compute_dtype is None:
1933
+ compute_dtype = output_dtype
1934
+ if num_threads < 0:
1935
+ num_threads = parallel_num_threads()
1936
+ vec_isa = pick_vec_isa()
1937
+ matched_configs = []
1938
+ for cls, configs in micro_gemm_configs.items():
1939
+ for config in configs:
1940
+ if not issubclass(vec_isa.__class__, config.vec_isa_cls):
1941
+ continue
1942
+ if (
1943
+ config.input_dtype == input_dtype
1944
+ and config.compute_dtype == compute_dtype
1945
+ and config.input2_dtype == input2_dtype
1946
+ and config.output_dtype == output_dtype
1947
+ # The output_dtype here is the output dtype of the micro-kernel.
1948
+ # In some cases, the actual output dtype of the op for which the micro-kernel
1949
+ # is being created would be same as that of the activation, but the micro-kernels
1950
+ # compute output in Float/int32, which is converted in the GEMM template. This is
1951
+ # subject to change in the future.
1952
+ ):
1953
+ if config.extra_check is not None and not config.extra_check(
1954
+ config,
1955
+ m,
1956
+ n,
1957
+ k,
1958
+ alpha,
1959
+ num_threads,
1960
+ dynamic_M=dynamic_M,
1961
+ q_group_size=q_group_size,
1962
+ ):
1963
+ continue
1964
+ block_m, block_n, block_k = config.register_blocking
1965
+ if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(
1966
+ config, dynamic_M, cls
1967
+ ):
1968
+ continue
1969
+ # Criteria on the ranking of configurations
1970
+ # 1. ISA: AMX > VEC
1971
+ # 2. Dividable by block sizes (block_m, block_n, block_k)
1972
+ # 3. Number of mxn blocks is large enough to occupy all the threads
1973
+ # 4. Register blocks are larger
1974
+ isa_score = 0
1975
+ if config.vec_isa_cls == VecAMX:
1976
+ isa_score += 1
1977
+ dividable_score = 0
1978
+ if m % block_m == 0:
1979
+ dividable_score += 1
1980
+ if n % block_n == 0:
1981
+ dividable_score += 1
1982
+ if k % block_k == 0:
1983
+ dividable_score += 1
1984
+ occupancy_score = 0
1985
+ n_blocks = (n + block_n - 1) // block_n
1986
+ total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
1987
+ if n_blocks >= num_threads:
1988
+ occupancy_score += 1
1989
+ if total_mxn_blocks >= num_threads:
1990
+ occupancy_score += 1
1991
+ register_bytes = (
1992
+ block_m * block_n * config.compute_dtype.itemsize
1993
+ + (block_m * block_k + block_k * block_n)
1994
+ * config.input_dtype.itemsize
1995
+ )
1996
+ matched_configs.append(
1997
+ (
1998
+ (isa_score, dividable_score, occupancy_score, register_bytes),
1999
+ cls,
2000
+ config,
2001
+ )
2002
+ )
2003
+ if len(matched_configs) == 0:
2004
+ if use_ref:
2005
+ return CppMicroGemmRef(
2006
+ name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
2007
+ )
2008
+ else:
2009
+ return None
2010
+ # TODO(jgong5): allow autotuning on choices of configs
2011
+ return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:])
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import ctypes
3
+ import functools
4
+ import itertools
5
+ import logging
6
+ import sys
7
+ from collections.abc import Iterable
8
+ from typing import Callable, Optional, Union
9
+ from unittest.mock import patch
10
+
11
+ import sympy
12
+
13
+ from .. import config, ir
14
+ from ..autotune_process import CppBenchmarkRequest, TensorMeta
15
+ from ..utils import IndentedBuffer, Placeholder, unique
16
+ from ..virtualized import V
17
+ from .common import KernelTemplate
18
+ from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
19
+
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+
24
+ class CppTemplate(KernelTemplate):
25
+ index_counter = itertools.count()
26
+
27
+ def __init__(
28
+ self,
29
+ name: str,
30
+ input_nodes,
31
+ layout: ir.Layout,
32
+ num_threads: int,
33
+ epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
34
+ ) -> None:
35
+ super().__init__(name)
36
+ self.input_nodes = input_nodes
37
+ self.index = next(self.index_counter)
38
+ self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer(
39
+ name=f"buf_out{self.index}", layout=layout
40
+ )
41
+ self.layout = layout
42
+ self.num_threads = num_threads
43
+ self.epilogue_creator = epilogue_creator
44
+
45
+ def generate(self, **kwargs):
46
+ kernel_name = f"cpp_{self.name}"
47
+ with (
48
+ patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
49
+ patch.object(ir.FlexibleLayout, "allow_indexing", True),
50
+ V.graph.set_current_device(self.layout.device),
51
+ CppTemplateKernel(
52
+ kernel_name=kernel_name, num_threads=self.num_threads
53
+ ) as kernel,
54
+ ):
55
+ code = kernel.render(self, **kwargs)
56
+ _, call_args, _, _ = kernel.args.python_argdefs()
57
+ log.debug("Generated Code:\n%s", code)
58
+ log.debug(
59
+ "Args: cpp_argdefs: %s, python_argdefs: %s",
60
+ kernel.args.cpp_argdefs(),
61
+ kernel.args.python_argdefs(),
62
+ )
63
+
64
+ expected_args = list(
65
+ unique(input_node.get_name() for input_node in self.input_nodes)
66
+ )
67
+ if isinstance(self.output_node, Iterable):
68
+ expected_args.extend([node.get_name() for node in self.output_node])
69
+ else:
70
+ expected_args.extend([self.output_node.get_name()])
71
+ assert list(call_args)[: len(expected_args)] == expected_args, (
72
+ call_args,
73
+ expected_args,
74
+ )
75
+ extra_args = V.graph.sizevars.size_hints(
76
+ map(sympy.expand, call_args[len(expected_args) :])
77
+ )
78
+ # Cast the size hint from int to ctypes.c_ulonglong explicitly
79
+ # since in cpp kernel, we bind it to C long
80
+ extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args)
81
+
82
+ kernel_hash_name = f"cpp_{self.name}_{self.index}"
83
+
84
+ # Create the BenchmarkRequest for CPP
85
+ bmreq = CppBenchmarkRequest(
86
+ kernel_name=kernel_name,
87
+ input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
88
+ output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
89
+ extra_args=extra_args,
90
+ source_code=code,
91
+ )
92
+
93
+ def make_kernel_render(
94
+ template_node: ir.CppTemplateBuffer,
95
+ flag_template_buffer_has_other_users: bool,
96
+ epilogue_nodes: Optional[list[ir.IRNode]] = None,
97
+ ):
98
+ kernel = CppTemplateKernel(
99
+ kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads
100
+ )
101
+ render = functools.partial(
102
+ kernel.render,
103
+ self,
104
+ template_buffer_node=template_node,
105
+ flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
106
+ epilogue_nodes=epilogue_nodes,
107
+ **kwargs,
108
+ )
109
+ return kernel, render
110
+
111
+ return CppTemplateCaller(
112
+ kernel_hash_name,
113
+ self.name,
114
+ self.input_nodes,
115
+ self.output_node[0].get_layout()
116
+ if isinstance(self.output_node, Iterable)
117
+ else self.output_node.get_layout(),
118
+ make_kernel_render,
119
+ bmreq,
120
+ self,
121
+ )
122
+
123
+ def header(self) -> IndentedBuffer:
124
+ res = IndentedBuffer()
125
+ res.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")
126
+ # TODO: add c10::ForcedUnroll test to test_aoti_abi_check
127
+ res.splice("""#include <c10/util/Unroll.h>""")
128
+ res.splice("""#include <torch/csrc/inductor/aoti_torch/c/shim.h>""")
129
+ enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
130
+ "linux",
131
+ "win32",
132
+ ]
133
+ if enable_kernel_profile:
134
+ res.writelines(["#include <ATen/record_function.h>"])
135
+ return res
136
+
137
+ def render(self, **kwargs) -> str:
138
+ raise NotImplementedError
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import itertools
3
+ from collections.abc import Iterable
4
+ from typing import Any, Callable, Optional, Union
5
+
6
+ import sympy
7
+ from sympy.parsing.sympy_parser import parse_expr
8
+
9
+ import torch
10
+ from torch._inductor.utils import do_bench_using_profiling
11
+ from torch.utils._ordered_set import OrderedSet
12
+ from torch.utils._sympy.symbol import SymT
13
+
14
+ from .. import config, cpp_builder, ir, lowering as L
15
+ from ..autotune_process import CppBenchmarkRequest
16
+ from ..loop_body import LoopBody
17
+ from ..select_algorithm import PartialRender
18
+ from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
19
+ from ..virtualized import V
20
+ from .common import REMOVED
21
+ from .cpp import CppKernel, CppKernelProxy, KernelGroup
22
+ from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
23
+
24
+
25
+ def parse_expr_with_index_symbols(expr):
26
+ if isinstance(expr, sympy.Expr):
27
+ return expr
28
+ elif isinstance(expr, (list, tuple)):
29
+ return [parse_expr_with_index_symbols(e) for e in expr]
30
+ else:
31
+ expr = parse_expr(str(expr))
32
+ int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols}
33
+ return expr.subs(int_symbols)
34
+
35
+
36
+ def wrap_with_tensorbox(node) -> ir.TensorBox:
37
+ return (
38
+ ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node)
39
+ )
40
+
41
+
42
+ class CppTemplateKernel(CppKernel):
43
+ def __init__(self, kernel_name, num_threads):
44
+ super().__init__(None, num_threads)
45
+ self.kernel_name = kernel_name
46
+ self.render_hooks = {}
47
+ self.local_buffers = {}
48
+
49
+ def render(self, template, **kwargs):
50
+ return PartialRender(
51
+ template.render(kernel=self, **kwargs), self.render_hooks
52
+ ).finalize_all()
53
+
54
def def_kernel(
    self,
    inputs: dict[str, ir.Buffer],
    outputs: dict[str, ir.Buffer],
    aliases: Optional[dict[str, str]] = None,
    function_name: str = "",
    extra_sizevars: Optional[list[sympy.Expr]] = None,
    placeholder: str = "<DEF_KERNEL>",
) -> str:
    """
    Register the kernel's inputs/outputs/size variables with ``self.args``
    and return ``placeholder``, which is later substituted by a render hook
    with the generated C++ function signature.

    Args:
        inputs: template-local name -> input buffer (None entries are skipped).
        outputs: template-local name -> output buffer.
        aliases: alias name -> original template-local name; aliases are
            temporarily registered and removed again inside the hook.
        function_name: C++ function name; defaults to ``self.kernel_name``.
        extra_sizevars: extra symbolic expressions whose free symbols must
            become kernel size arguments.
        placeholder: marker string emitted into the rendered template.

    Returns:
        The ``placeholder`` string (finalized later via ``self.render_hooks``).
    """
    if len(function_name) == 0:
        function_name = str(self.kernel_name)
    for name, inp in inputs.items():
        if inp is not None:
            self.args.input_buffers[inp.get_name()] = name
    for name, out in outputs.items():
        self.args.output_buffers[out.get_name()] = name
    if aliases is not None:
        # Aliases resolve to the same argument name as the buffer they alias.
        for alias, orig in aliases.items():
            if orig in self.args.input_buffers:
                self.args.input_buffers[alias] = self.args.input_buffers[orig]
            if orig in self.args.output_buffers:
                self.args.output_buffers[alias] = self.args.output_buffers[orig]

    # Collect every free symbol appearing in input/output sizes and strides
    # (plus any caller-supplied extras): each becomes a size-var argument.
    unique_sizevars = OrderedSet(
        s
        for input in inputs.values()
        if input is not None
        for sym in itertools.chain(input.get_size(), input.get_stride())
        if isinstance(sym, sympy.Expr)
        for s in sym.free_symbols
    )
    unique_sizevars.update(
        s
        for sym in extra_sizevars or []
        if isinstance(sym, sympy.Expr)
        for s in sym.free_symbols
    )
    unique_sizevars.update(
        s
        for output in outputs.values()
        for sym in itertools.chain(output.get_size(), output.get_stride())
        if isinstance(sym, sympy.Expr)
        for s in sym.free_symbols
    )
    # Sort by name for a deterministic argument order.
    sizevars = sorted(unique_sizevars, key=str)
    for sizevar in sizevars:
        self.args.sizevars[sizevar] = f"k{sizevar}"

    def hook():
        # remove all aliases before generate function definition
        if aliases is not None:
            for alias in aliases:
                if alias in self.args.input_buffers:
                    raise AssertionError(
                        f"input_buffers cannot be removed: {alias}"
                    )
                if alias in self.args.output_buffers:
                    self.args.output_buffers[alias] = REMOVED
        cpp_argdefs, _, _ = self.args.cpp_argdefs()
        return f"void {function_name}({', '.join(cpp_argdefs)})"

    assert placeholder not in self.render_hooks
    self.render_hooks[placeholder] = hook
    return placeholder
118
+
119
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
    """Emit the wrapper-side call to the generated (non-Triton) kernel."""
    _, runtime_args, runtime_arg_types = self.args.cpp_argdefs()
    V.graph.wrapper_code.generate_kernel_call(
        name, runtime_args, triton=False, arg_types=runtime_arg_types
    )
123
+
124
def dtype(self, node: ir.Buffer) -> str:
    """Return the C++ scalar type name corresponding to *node*'s dtype."""
    torch_dtype = node.get_dtype()
    return DTYPE_TO_CPP[torch_dtype]
126
+
127
def acc_dtype(self, node: ir.Buffer) -> str:
    """Return the C++ accumulator type for *node*: float accumulates fp32/bf16/fp16."""
    node_dtype = node.get_dtype()
    if node_dtype not in (torch.float32, torch.bfloat16, torch.half):
        raise NotImplementedError(f"Unsupported dtype: {node_dtype}")
    return "float"
132
+
133
def size(self, node: ir.Buffer, dim: int) -> str:
    """Render the size of *node* along *dim* as a C++ index expression."""
    dim_size = node.get_size()[dim]
    return cexpr_index(self.rename_indexing(dim_size))
135
+
136
def stride(self, node: ir.Buffer, dim: int) -> str:
    """Render the stride of *node* along *dim* as a C++ index expression."""
    dim_stride = node.get_stride()[dim]
    return cexpr_index(self.rename_indexing(dim_stride))
138
+
139
def index(self, node: ir.Buffer, indices: list[Any]) -> str:
    """Render a C++ element access ``name[flat_index]`` for *node* at *indices*.

    Local buffers keep their own name; everything else is routed through
    ``self.args.input`` so the buffer becomes a kernel argument.
    """
    flat_index = node.get_layout().as_fixed().make_indexer()(
        parse_expr_with_index_symbols(indices)
    )
    flat_index = self.rename_indexing(flat_index)
    buf_name = node.get_name()
    if buf_name in self.local_buffers:
        access_name = buf_name
    else:
        access_name = self.args.input(node.get_name())
    return f"{access_name}[{cexpr_index(flat_index)}]"
150
+
151
def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView:
    """
    Slice the given node with a list of ranges (start and end) corresponding to its dims.
    The dim is not sliced if the corresponding range is empty.

    Uses ``clamp=False`` so the symbolic start/end are taken as-is rather
    than clamped to the dim size.
    """
    assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
    sliced = wrap_with_tensorbox(node)
    for dim, _range in enumerate(ranges):
        if len(_range) == 0:
            # Empty range: keep this dim unsliced.
            continue
        assert len(_range) == 2
        start, end = parse_expr_with_index_symbols(_range)
        sliced = L.slice_(sliced, dim, start, end, clamp=False)
    assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
    return sliced.data
166
+
167
def select(self, node, dim: int, idx: int) -> ir.ReinterpretView:
    """Select index *idx* along *dim*, dropping that dim (like tensor.select)."""
    # We avoid using L.select here because we need clamp=False so the dim after slicing
    # is 1 instead of a sympy expression of symbol - dim_size.
    node = wrap_with_tensorbox(node)
    idx = ir.View.handle_negative_index(idx, node.get_size()[dim])
    # Slice [idx, idx+1) then squeeze the now size-1 dim away.
    sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim)
    assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
    return sliced.data
175
+
176
def view(self, node, sizes: list[Any]) -> ir.View:
    """Return a view of *node* reshaped to *sizes* (sizes may be symbolic)."""
    boxed = wrap_with_tensorbox(node)
    parsed_sizes = parse_expr_with_index_symbols(sizes)
    return L.view(boxed, parsed_sizes).data
180
+
181
def permute(self, node, dims):
    """Permute *node*'s dims; the result must lower to a ReinterpretView."""
    boxed = wrap_with_tensorbox(node)
    permuted_view = L.permute(boxed, dims).data
    assert isinstance(permuted_view, ir.ReinterpretView)
    return permuted_view
186
+
187
def maybe_codegen_profile(self) -> str:
    """Emit a RECORD_FUNCTION profiling line when kernel profiling is enabled."""
    if not config.cpp.enable_kernel_profile:
        return ""
    graph_id = V.graph.graph_id
    # Prefix the event name with the graph id when one is available.
    prefix = f"graph_{graph_id}_" if graph_id is not None else ""
    return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
194
+
195
def unroll_pragma(self, unroll):
    """Return the compiler-appropriate loop-unroll pragma for factor *unroll*."""
    pragma_prefix = "#pragma GCC unroll" if cpp_builder.is_gcc() else "#pragma unroll"
    return f"{pragma_prefix} {unroll}"
200
+
201
def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str:
    """Define kernel local buffer

    Registers a heap-allocated local buffer and returns the C++ declaration:
    a unique_ptr ``_name`` plus a raw pointer alias ``name``.
    """
    parsed_sizes = parse_expr_with_index_symbols(sizes)
    local_buf = ir.Buffer(
        name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, parsed_sizes)
    )
    self.local_buffers[name] = local_buf
    ctype = f"{DTYPE_TO_CPP[dtype]}"
    numel = f"{cexpr_index(local_buf.get_numel())}"
    return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"
211
+
212
def define_stack_allocated_buffer(
    self, name, sizes: list[Any], dtype=torch.float
) -> str:
    """Define stack-allocated buffer

    Like ``define_buffer`` but emits a 64-byte-aligned stack array
    ``_name`` with a raw pointer alias ``name``.
    """
    parsed_sizes = parse_expr_with_index_symbols(sizes)
    local_buf = ir.Buffer(
        name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, parsed_sizes)
    )
    self.local_buffers[name] = local_buf
    ctype = f"{DTYPE_TO_CPP[dtype]}"
    numel = f"{cexpr_index(local_buf.get_numel())}"
    return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};"
224
+
225
def reinit_buffer_if_null(self, name):
    """Reinit the previously defined local buffer if it is null"""
    assert name in self.local_buffers
    local_buf = self.local_buffers[name]
    ctype = f"{DTYPE_TO_CPP[local_buf.layout.dtype]}"
    numel = f"{cexpr_index(local_buf.get_numel())}"
    return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}"
232
+
233
def release_buffer(self, name):
    """Codegen the code to release the ownership of a local buffer to others"""
    assert name in self.local_buffers, name
    return f"_{name}.release()"
237
+
238
def store_pointwise_nodes(
    self,
    dst: ir.Buffer,
    nodes: list[ir.IRNode],
    offsets: Optional[list[sympy.Expr]] = None,
    reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
) -> str:
    """
    Codegen C++ loops that evaluate the pointwise *nodes* and store into *dst*.

    All nodes but the last store into their own buffer name; the final node
    stores into *dst*. ``offsets`` shift the loop indices before feeding the
    node loaders; ``reindexers`` (one per node, optional) remap indices to
    match each node's own indexing. Returns the generated loop source.
    """
    var_sizes = (tuple(dst.get_size()), ())
    var_ranges = {
        sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
        for i, sz in enumerate(var_sizes[0])
    }
    if not offsets:
        offsets = [sympy.S.Zero] * len(var_sizes[0])
    if not reindexers:
        reindexers = [None] * len(nodes)
    assert len(offsets) == len(var_sizes[0])
    # Flat store index into dst in terms of the loop itervars.
    output_index = dst.get_layout().make_indexer()([*var_ranges.keys()])
    kernel_group = KernelGroup()
    kernel_group.args = self.args
    cpp_kernel_proxy = CppKernelProxy(kernel_group)
    bodies = []
    var_sizes_list = []
    for i, node in enumerate(nodes):
        # Only the last node writes to dst; earlier nodes keep their names.
        output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name()
        node = node.data if isinstance(node, ir.ComputedBuffer) else node
        assert isinstance(node, ir.Pointwise), node

        def fn(*args):
            # args = (index_vars, reduction_vars); pointwise => no reductions.
            assert len(args) == 2
            assert len(args[0]) == len(var_sizes[0])
            assert len(args[1]) == 0
            new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
            if reindexers[i] is not None:
                new_args = reindexers[i](new_args)  # type: ignore[misc]
            V.ops.store(
                output_name,
                output_index,
                node.make_loader()(new_args).value,
            )

        # NOTE: fn is consumed by LoopBody inside this same iteration, so the
        # closure over the loop variables i/node/output_name is safe.
        body = LoopBody(
            fn,
            (list(var_ranges.keys()), ()),
            var_ranges,
            list(var_ranges.keys()),
            tuple(),
        )
        bodies.append(body)
        var_sizes_list.append(var_sizes)

    cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
    kernel_group.finalize_kernel(cpp_kernel_proxy, [])
    return kernel_group.loops_code.getvalue()
292
+
293
def store_grouped_gemm_pointwise_nodes(
    self,
    dst: tuple[ir.Buffer],
    nodes: list[ir.IRNode],
    offsets: list[sympy.Expr],
    reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
    output_names: list[str],
) -> str:
    """
    Grouped-GEMM variant of ``store_pointwise_nodes``: codegen loops that
    evaluate *nodes* and store each one into its matching ``output_names[i]``.

    Loop ranges and the flat store index are derived from ``dst[0]``; the
    caller guarantees all dst buffers share that shape. Unlike the
    single-output variant, ``offsets`` is a per-node list of offset lists
    and must be provided by the caller. Returns the generated loop source.
    """
    # All outputs share the shape of the first dst buffer.
    ref_dst = dst[0]
    var_sizes = (tuple(ref_dst.get_size()), ())
    var_ranges = {
        sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
        for i, sz in enumerate(var_sizes[0])
    }
    assert offsets, "offsets should be set outside"
    assert all(len(offset) == len(var_sizes[0]) for offset in offsets)
    output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
    kernel_group = KernelGroup()
    kernel_group.args = self.args
    cpp_kernel_proxy = CppKernelProxy(kernel_group)
    bodies = []
    var_sizes_list = []
    for i, node in enumerate(nodes):
        output_name = output_names[i]
        node = node.data if isinstance(node, ir.ComputedBuffer) else node
        assert isinstance(node, ir.Pointwise), node

        def fn(*args):
            # args = (index_vars, reduction_vars); pointwise => no reductions.
            assert len(args) == 2
            assert len(args[0]) == len(var_sizes[0])
            assert len(args[1]) == 0
            new_args = [arg + offset for arg, offset in zip(args[0], offsets[i])]  # type: ignore[arg-type]
            if reindexers[i] is not None:
                new_args = reindexers[i](new_args)  # type: ignore[misc]
            V.ops.store(
                output_name,
                output_index,
                node.make_loader()(new_args).value,
            )

        # fn is consumed immediately by LoopBody, so closing over i is safe.
        body = LoopBody(
            fn,
            (list(var_ranges.keys()), ()),
            var_ranges,
            list(var_ranges.keys()),
            tuple(),
        )
        bodies.append(body)
        var_sizes_list.append(var_sizes)

    cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
    kernel_group.finalize_kernel(cpp_kernel_proxy, [])
    return kernel_group.loops_code.getvalue()
346
+
347
def store_output(
    self,
    dst: ir.Buffer,
    src: ir.Buffer,
    orig_src: Optional[ir.Buffer] = None,
    epilogue_nodes: Optional[list[ir.IRNode]] = None,
    offsets: Optional[list[Any]] = None,
    reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
):
    """
    Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match.
    If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues
    before stored to `dst`. The `epilogues_nodes` are all pointwise.

    Notes:
    1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute
       and stores. In case `epilogue_nodes` are not provided, we do nothing.
    2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since
       they come form the original Inductor IR, they might need to be adjusted before working with
       `src` and `dst` as outlined below:
       a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on.
          In this case, the `offsets` could be provided to adjust the indices passed to
          `epilogue_nodes` during codegen and the data ranges are also configured according to
          the sizes of `src` and `dst`.
       b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
          needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
       c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
          in `epilogue_nodes` with `src`.

    Returns the generated C++ store/epilogue source ("" for the in-place no-op case).
    """
    assert isinstance(dst, (ir.Buffer, ir.ReinterpretView))
    assert dst.get_size() == src.get_size(), f"{dst=}, {src=}"
    if offsets:
        offsets = parse_expr_with_index_symbols(offsets)
    if epilogue_nodes:
        with LocalBufferContext(self.args) as scope:
            assert orig_src is not None
            if orig_src.get_name() != src.get_name():
                # src is a template-local stand-in for orig_src: register it
                # as a local buffer so epilogue reads are redirected to it.
                scope.add_local_buffer(
                    src,
                    [
                        orig_src,
                    ],
                )
                epilogue_nodes = scope.localize_nodes(epilogue_nodes)
            return self.store_pointwise_nodes(
                dst,
                epilogue_nodes,  # type: ignore[arg-type]
                offsets,
                reindexers,
            )
    else:
        if dst.get_name() != src.get_name():
            # src is local
            copy = L.copy(dst, src).data.data
            with LocalBufferContext(self.args) as scope:
                scope.add_local_buffer(src)
                return self.store_pointwise_nodes(dst, [copy])
        else:
            # In-place with no epilogues: nothing to emit.
            assert dst.layout == src.layout, f"{dst=}, {src=}"
            return ""
407
+
408
def store_outputs(
    self,
    dst: tuple[ir.Buffer],
    src: tuple[ir.IRNode],
    orig_src: Optional[tuple[ir.IRNode]] = None,
    epilogue_nodes: Optional[list[ir.IRNode]] = None,
    offsets: Optional[list[Any]] = None,
    reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
    multi_output_buffers: Optional[tuple[ir.MultiOutput]] = None,
):
    """
    Grouped-GEMM analogue of ``store_output``: store each ``src[i]`` into
    ``dst[i]``, optionally applying pointwise ``epilogue_nodes`` first.

    GEMM outputs that an epilogue reads (directly or via their MultiOutput
    wrapper) are localized into kernel-local buffers; the rest are copied
    straight to their destinations. Returns the generated C++ source
    ("" when everything is already in place).
    """
    assert isinstance(dst, Iterable)
    assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
    if offsets:
        offsets = parse_expr_with_index_symbols(offsets)
    gemm_num = len(src)
    final_offsets = []
    output_names = []
    if epilogue_nodes:
        if not reindexers:
            reindexers = [None] * len(epilogue_nodes)
        with LocalBufferContext(self.args) as scope:
            assert orig_src is not None
            localize_epilogue_nodes = []
            # Collect everything the epilogues read so we can decide which
            # GEMM outputs must be localized.
            all_read_names = []
            for epilogue in epilogue_nodes:
                all_read_names.extend(list(epilogue.get_read_names()))
            localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes))
            final_offsets.extend([offsets] * len(localize_epilogue_nodes))
            output_names.extend(
                [node.get_name() for node in localize_epilogue_nodes]
            )
            for gemm_idx in range(gemm_num):
                if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name():
                    if orig_src[gemm_idx].get_name() in all_read_names or (
                        multi_output_buffers
                        and multi_output_buffers[gemm_idx].get_name()
                        in all_read_names
                    ):
                        # If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output
                        global_buffers = [orig_src[gemm_idx]]
                        if (
                            multi_output_buffers
                            and multi_output_buffers[gemm_idx].get_name()
                            in all_read_names
                            and orig_src[gemm_idx].get_name() not in all_read_names
                        ):
                            # Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer
                            # if this MultiOutput has not been stored by in-template epilogue
                            # otherwise, use the cse store cache if it will be stored before used
                            global_buffers.append(multi_output_buffers[gemm_idx])
                        scope.add_local_buffer(
                            src[gemm_idx],
                            global_buffers,
                        )
                    else:
                        # Unused by epilogues: still local, but no aliasing of
                        # global buffers is needed — emit a plain copy to dst.
                        scope.add_local_buffer(src[gemm_idx])
                        localize_epilogue_nodes.extend(
                            [L.copy(dst[gemm_idx], src[gemm_idx]).data.data]
                        )
                        reindexers.append(None)
                        output_names.append(dst[gemm_idx].get_name())
                        final_offsets.append(
                            [sympy.S.Zero] * len(dst[gemm_idx].get_size())
                        )
            res = self.store_grouped_gemm_pointwise_nodes(
                dst,
                localize_epilogue_nodes,
                final_offsets,
                reindexers,
                output_names=output_names,
            )
            for gemm_idx in range(gemm_num):
                if (
                    multi_output_buffers
                    and multi_output_buffers[gemm_idx].get_name() in all_read_names
                ):
                    # If the MultiOutput is used in the Epilogue, let's remove it from args
                    multi_output_name = multi_output_buffers[gemm_idx].get_name()
                    if (
                        multi_output_name in self.args.output_buffers
                        and self.args.output_buffers[multi_output_name]
                        is not REMOVED
                    ):
                        self.remove_buffer(multi_output_name)
            return res
    else:
        if dst[0].get_name() != src[0].get_name():
            # No epilogues, but srcs are local: copy each src into its dst.
            copy_list = []
            with LocalBufferContext(self.args) as scope:
                for _src, _dst in zip(src, dst):
                    copy_list.extend([L.copy(_dst, _src).data.data])
                    scope.add_local_buffer(_src)
                    output_names.append(_dst.get_name())
                    final_offsets.append([sympy.S.Zero] * len(_dst.get_size()))
                reindexers = [None] * len(copy_list)
                return self.store_grouped_gemm_pointwise_nodes(
                    dst,
                    nodes=copy_list,
                    offsets=final_offsets,
                    reindexers=reindexers,
                    output_names=output_names,
                )
        else:
            # Fully in place: every src already is its dst.
            assert all(
                _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
            )
            assert all(
                _src.get_layout() == _dst.get_layout()
                for _src, _dst in zip(src, dst)
            )
            return ""
519
+
520
def check_bounds(self, expr, size, lower, upper):
    """Intentional no-op: template kernels skip bounds-check codegen."""
    # CppTemplateKernel does not need codegen related operations
    return None
523
+
524
+
525
class CppTemplateCaller(ir.ChoiceCaller):
    """
    CppTemplateCaller

    This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CppBenchmarkRequest): The benchmark request for the caller.
        template_buffer (ir.CppTemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: list[ir.Buffer],
        layout: ir.Layout,
        make_kernel_render: Callable[
            [
                ir.CppTemplateBuffer,
                bool,
                Optional[list[ir.IRNode]],
            ],
            str,
        ],
        bmreq: CppBenchmarkRequest,
        template: "CppTemplate",  # type: ignore[name-defined] # noqa: F821
        info_kwargs: Optional[
            dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]
        ] = None,
    ):
        super().__init__(name, input_nodes, layout, description="")
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        """Delegate ahead-of-time compilation to the benchmark request."""
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        """Benchmark this choice, optionally via the profiling-based timer."""
        assert self.bmreq is not None
        if config.profile_bandwidth_with_do_bench_using_profiling:
            algo = self.bmreq.make_run_fn(*args, out=out)
            return do_bench_using_profiling(algo)
        return self.bmreq.benchmark(*args, out=out)

    def hash_key(self) -> str:
        """Cache key: category plus the benchmark request's hash."""
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def info_dict(
        self,
    ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]:
        """Minimal metadata used by autotune logging."""
        return {"backend": "CPP", "op_type": "unknown"}

    def output_node(self) -> ir.TensorBox:
        """Materialize this choice as a CppTemplateBuffer wrapped in a TensorBox."""
        return ir.TensorBox.create(
            ir.CppTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                template=self.template,
                choice=self,
            )
        )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import dataclasses
4
+ import functools
5
+ import math
6
+ import sys
7
+ from collections import namedtuple
8
+ from collections.abc import Sequence
9
+ from typing import Any, Callable, Optional
10
+ from unittest.mock import patch
11
+
12
+ import sympy
13
+
14
+ import torch
15
+ from torch._prims_common import is_integer_dtype
16
+ from torch.utils._ordered_set import OrderedSet
17
+ from torch.utils._sympy.printers import CppPrinter as _CppPrinter
18
+ from torch.utils._sympy.symbol import symbol_is_type, SymT
19
+ from torch.utils._sympy.value_ranges import ValueRanges
20
+
21
+ from .. import ir
22
+ from ..dependencies import Dep
23
+ from ..loop_body import LoopBody
24
+ from ..scheduler import BaseSchedulerNode, SchedulerBuffer
25
+ from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
26
+ from ..virtualized import ops, OpsValue, V
27
+ from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
28
+
29
+
30
+ DTYPE_TO_CPP = {
31
+ torch.float32: "float",
32
+ torch.float64: "double",
33
+ torch.float16: "at::Half",
34
+ torch.int64: "int64_t",
35
+ torch.int32: "int32_t",
36
+ torch.int16: "int16_t",
37
+ torch.int8: "int8_t",
38
+ torch.uint64: "uint64_t",
39
+ torch.uint32: "uint32_t",
40
+ torch.uint16: "uint16_t",
41
+ torch.uint8: "uint8_t",
42
+ torch.bool: "bool",
43
+ torch.bfloat16: "at::BFloat16",
44
+ torch.complex32: "at::complex<at::Half>",
45
+ torch.complex64: "at::complex<float>",
46
+ torch.complex128: "at::complex<double>",
47
+ torch.float8_e4m3fn: "at::Float8_e4m3fn",
48
+ torch.float8_e5m2: "at::Float8_e5m2",
49
+ torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz",
50
+ torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz",
51
+ }
52
+
53
+ DTYPE_TO_ATEN = {
54
+ torch.float32: "at::kFloat",
55
+ torch.float64: "at::kDouble",
56
+ torch.float16: "at::kHalf",
57
+ torch.int64: "at::kLong",
58
+ torch.int32: "at::kInt",
59
+ torch.int16: "at::kShort",
60
+ torch.int8: "at::kChar",
61
+ torch.uint64: "at::kUInt64",
62
+ torch.uint32: "at::kUInt32",
63
+ torch.uint16: "at::kUInt16",
64
+ torch.uint8: "at::kByte",
65
+ torch.uint32: "at::kUInt32",
66
+ torch.uint64: "at::kUInt64",
67
+ torch.bool: "at::kBool",
68
+ torch.bfloat16: "at::kBFloat16",
69
+ torch.complex32: "at::kComplexHalf",
70
+ torch.complex64: "at::kComplexFloat",
71
+ torch.complex128: "at::kComplexDouble",
72
+ torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
73
+ torch.float8_e5m2: "at::kFloat8_e5m2",
74
+ torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
75
+ torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
76
+ }
77
+
78
+ DEVICE_TO_ATEN = {
79
+ "meta": "at::kMeta",
80
+ "cpu": "at::kCPU",
81
+ "cuda": "at::kCUDA",
82
+ "xpu": "at::kXPU",
83
+ "mps": "at::kMPS",
84
+ }
85
+
86
+ LAYOUT_TO_ATEN = {
87
+ torch.strided: "at::kStrided",
88
+ torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
89
+ }
90
+
91
+ # matches c10/core/DeviceType.h
92
+ DEVICE_TO_INT = {"cpu": 0, "cuda": 1}
93
+
94
+ _IS_WINDOWS = sys.platform == "win32"
95
+
96
+ INDEX_TYPE = "int64_t"
97
+
98
+ GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
99
+
100
+
101
def get_promote_dtype(args):
    """Return the torch dtype that all CppCSEVariable args promote to.

    Returns None when there is not enough information: either some
    CppCSEVariable has an unknown (None) dtype, or — new guard — no arg is a
    CppCSEVariable at all (the original code crashed with TypeError on
    ``functools.reduce`` over an empty sequence in that case).
    """
    dtypes = [n.dtype for n in args if isinstance(n, CppCSEVariable)]
    if not dtypes or any(dtype is None for dtype in dtypes):
        return None  # not enough info to calculate the promote dtype
    return functools.reduce(torch.promote_types, dtypes)  # type: ignore[arg-type]
110
+
111
+
112
def promote_args(new_args):
    """
    Promote all CppCSEVariable args to their common dtype (via
    ``get_promote_dtype``) by emitting ``ops.to_dtype`` conversions.
    Args are returned unchanged when any dtype is unknown or there is no
    common promote type.
    """

    def promote_arg(arg, promote_type):
        # Convert only CSE variables whose dtype is known and differs.
        if (
            isinstance(arg, CppCSEVariable)
            and arg.dtype
            and promote_type
            and arg.dtype != promote_type
        ):
            arg = ops.to_dtype(arg, promote_type)
            # to_dtype may hand back an OpsValue wrapper; unwrap it.
            arg = arg.value if isinstance(arg, OpsValue) else arg
            arg.dtype = promote_type
        return arg

    promote_type = get_promote_dtype(new_args)
    promote_fn = functools.partial(
        promote_arg,
        promote_type=promote_type,
    )
    if (
        all(
            new_arg.dtype is not None
            for new_arg in new_args
            if isinstance(new_arg, CppCSEVariable)
        )
        and promote_type
    ):
        new_args = list(map(promote_fn, new_args))
    return new_args
140
+
141
+
142
class CppCSEVariable(CSEVariable):
    """CSE variable for the C++ backend.

    Tracks, in addition to the base class, whether the value is vectorized
    (``is_vec``) and which loop itervars it depends on
    (``dependent_itervars``).
    """

    def __init__(
        self,
        name,
        bounds: ValueRanges[Any],
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(name, bounds, dtype)
        # Whether this variable holds a vector value (set lazily from args).
        self.is_vec = False
        # Loop itervars this value depends on (drives vectorization checks).
        self.dependent_itervars = OrderedSet[sympy.Symbol]()

    def __repr__(self) -> str:
        return (
            f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
            f"dependent_itervars: {self.dependent_itervars})"
        )

    def update_on_args(self, name, args, kwargs):
        """Refresh dependent itervars / is_vec after op *name* produced this var."""
        if name == "load":
            # args[2] is index
            self._set_dependent_itervars(args[2])
        else:
            # propagate relevant itervars and is_vec from args
            self.dependent_itervars.update(
                *[
                    arg.dependent_itervars
                    for arg in args
                    if isinstance(arg, CppCSEVariable)
                ]
            )
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
                self.is_vec = True

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Set the relevant itervars for this variable based on the `index` expression.
        This includes the itervars directly used in the `index` as well as relevant itervars
        of other cse variables used in the `index`.
        """
        for s in index.free_symbols:
            if s in V.kernel.itervars:
                self.dependent_itervars.add(s)  # type: ignore[arg-type]
            elif s.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
                # Inherit the dependencies of the CSE variable named by s.
                self.dependent_itervars.update(
                    V.kernel.cse.varname_map[s.name].dependent_itervars  # type: ignore[attr-defined]
                )

    def depends_on(self, itervar: sympy.Symbol):
        """Return True if this value depends on the given loop itervar."""
        return itervar in self.dependent_itervars
193
+
194
+
195
class CppPrinter(_CppPrinter):
    """Sympy-to-C++ printer that simplifies via graph sizevars before printing."""

    def doprint(self, expr, *, simplify: bool = True, p=True):
        # TODO: why are people passing strings to the printer here :think:
        # NOTE(review): `p` is accepted but never used — presumably kept for
        # signature compatibility with callers; confirm before removing.
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
201
+
202
+
203
# Module-level printer: `cexpr(expr)` renders a sympy expression (symbols
# included) as C++ source text.
cexpr = CppPrinter().doprint
205
+
206
+
207
def cexpr_index(index):
    """Render *index* as C++ and cast it to INDEX_TYPE (int64_t)."""
    return f"static_cast<{INDEX_TYPE}>({cexpr(index)})"
209
+
210
+
211
def value_to_cpp(value, cpp_type):
    """Render the Python scalar *value* as a C++ expression of type *cpp_type*."""
    if isinstance(value, bool):
        # C++ spells boolean literals in lowercase.
        return f"static_cast<{cpp_type}>({str(value).lower()})"
    if value == float("-inf"):
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    if value == float("inf"):
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    if math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    return f"static_cast<{cpp_type}>({repr(value)})"
222
+
223
+
224
def rewrite_index_for_function(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """
    Rewrite a global-buffer index so it addresses the corresponding local
    buffer, which covers only the innermost dims of the defining op's loop
    nest: itervars outside those dims are zeroed out of the index.
    """
    # Local buffer at the inner dimensions
    snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
    assert snode is not None
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    scheduler_nodes = snode.get_nodes()
    # Take loop groups from the reduction node if present (widest ranges).
    _, (group, reduction_group) = max(
        scheduler_nodes, key=lambda x: int(x.is_reduction())
    ).group
    call_ranges = tuple(group) + tuple(reduction_group)
    # Keep only the innermost itervars, as many as the local buffer has dims.
    indices_to_keep = [
        f"x{len(call_ranges) - (idx + 1)}"
        for idx in range(len(local_buf.get_layout().size))
    ]
    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
    replacements = {}
    for x in sorted_symbols:
        if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined]
            # Only keep index used by local buffer
            replacements[x] = sympy.core.numbers.Zero()
    index = sympy_subs(index, replacements)  # type: ignore[arg-type]
    return index
250
+
251
+
252
def rewrite_index_for_nodes(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """
    Rewrite a global-buffer index for IR-node localization: rebuild the flat
    index from the local buffer's own layout, substituting 0 for any INDEX
    itervar the original expression does not actually use.
    """
    used_vars = OrderedSet(
        s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
    )
    index_vars = []
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    for i in range(len(local_buf.get_size())):
        var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
        # Unused dims collapse to 0 in the local layout's indexer.
        index_vars.append(var if var in used_vars else 0)
    index = local_buf.get_layout().make_indexer()(index_vars)
    return index
267
+
268
+
269
class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
    """Ops wrapper that redirects loads/stores on mapped global buffers to
    their local counterparts, rewriting indices via *rewrite_index*."""

    def __init__(
        self,
        inner,
        global_to_local: dict[str, ir.Buffer],
        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
    ) -> None:
        super().__init__(inner)
        self.global_to_local = global_to_local
        self.rewrite_index = rewrite_index

    def localize(self, name: str, index: sympy.Expr):
        """Map (global name, index) -> (local name, rewritten index); pass through otherwise."""
        if self.global_to_local and name in self.global_to_local:
            assert self.rewrite_index is not None
            index = self.rewrite_index(self, index, name)
            name = self.global_to_local[name].get_name()
        return name, index

    def load(self, name: str, index: sympy.Expr):
        return self._inner.load(*self.localize(name, index))

    def store(self, name, index, value, mode=None):
        local_buffer_name, local_buffer_index = self.localize(name, index)
        res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
        if (
            self.global_to_local
            and name in self.global_to_local
            and isinstance(V.kernel, Kernel)
        ):
            # Remove name of local buffer from Kernel.store_buffer_names
            # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
            V.kernel.store_buffer_names.discard(local_buffer_name)
        return res

    def store_reduction(self, name, index, value):
        return self._inner.store_reduction(*self.localize(name, index), value)
305
+
306
+
307
class LocalBufferContext:
    """
    This class creates a context that helps to generate code involving Inductor IR with
    function local buffers. These buffers are constructed during the codegen process and
    are used to store intermediate results such as local accumulators. We do not want to
    add them to `V.graph` since they are not global and we do not want to add them as
    function arguments either. So we patch the codegen processes under this scope to support
    these buffers without exposure to the outside world.
    """

    def __init__(self, kernel_args: KernelArgs) -> None:
        self.kernel_args = kernel_args
        # ExitStack collects every patch applied in __enter__ so that
        # __exit__ can undo them all in reverse order.
        self.exit_stack = contextlib.ExitStack()
        # map local buffer name to local buffer
        self.local_buffers: dict[str, ir.Buffer] = {}
        # map global buffer name to global buffer
        self.global_buffers: dict[str, ir.Buffer] = {}
        # map global buffer name to local buffer
        self.global_to_local: dict[str, ir.Buffer] = {}
        # record the global buffers that are removed by this LocalBufferContext
        self.removed_buffers: OrderedSet[str] = OrderedSet()

    def __enter__(self):
        """Patch dtype/argument lookups so that local-buffer names resolve
        even though they are not registered in ``V.graph``."""
        self.exit_stack.__enter__()
        original_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            # Local buffers are unknown to V.graph; answer for them here and
            # fall back to the original lookup for everything else.
            if name in self.local_buffers:
                return self.local_buffers[name].get_dtype()
            return original_get_dtype(name)

        self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))

        original_input = self.kernel_args.input

        def input(name):
            # A local buffer is referenced by its own name; it never becomes
            # a generated kernel-argument name.
            if name in self.local_buffers:
                return name
            return original_input(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))

        original_output = self.kernel_args.output

        def output(name):
            # Same as `input` above, but for stores.
            if name in self.local_buffers:
                return name
            return original_output(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))

        # Set current LocalBufferContext into V
        self.exit_stack.enter_context(V.set_local_buffer_context(self))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Drop local buffers first, then unwind every patch applied in
        # __enter__ via the ExitStack.
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(
        self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None
    ):
        """Register `local_buffer`; optionally alias each buffer in
        `global_buffers` to it and mark those globals as removed."""
        assert local_buffer.get_name() not in self.local_buffers
        self.local_buffers[local_buffer.get_name()] = local_buffer
        if global_buffers:
            for global_buffer in global_buffers:
                global_buffer_name = global_buffer.get_name()
                # Each global buffer may be localized at most once.
                assert (
                    global_buffer_name not in self.global_buffers
                    and global_buffer_name not in self.global_to_local
                )
                self.global_buffers[global_buffer_name] = global_buffer
                self.global_to_local[global_buffer_name] = local_buffer
                if global_buffer_name not in V.graph.removed_buffers:
                    # Record the global buffers that are removed by this LocalBufferContext
                    # since which may need to restore. Refer to issue:
                    # https://github.com/pytorch/pytorch/issues/144186
                    self.removed_buffers.add(global_buffer_name)
                V.graph.removed_buffers.add(global_buffer_name)

    def localize_function(
        self,
        fn: Callable[..., Any],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_function,
    ):
        """Wrap `fn` so that, while it runs, all loads/stores on registered
        global buffers are redirected to their local counterparts."""

        def inner(*args, **kwargs):
            with V.set_ops_handler(
                LocalizeBufferHandler(
                    V.get_ops_handler(),
                    global_to_local=self.global_to_local,
                    rewrite_index=rewrite_index,
                )
            ):
                return fn(*args, **kwargs)

        return inner

    def localize_nodes(
        self,
        nodes: list[ir.IRNode],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_nodes,
    ) -> list[ir.IRNode]:
        """
        Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
        though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
        for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
        instead of `global_buf`, i.e., all the loads and stores are redirected to
        `local_buf`. This helps the fused loops to work on smaller-sized local buffers
        for better data locality.

        The data access of `local_buf` is assumed to be contiguous with the
        same order as the `global_buf`.
        """
        assert len(nodes) > 0

        def wrap_inner_fn_for_node(node: ir.IRNode):
            # A ComputedBuffer wraps its Loops in `.data`; plain Loops nodes
            # are used directly.
            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(loops, ir.Loops)
            new_inner_fn = self.localize_function(
                loops.inner_fn,
                rewrite_index,
            )

            new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn)
            if isinstance(node, ir.ComputedBuffer):
                new_node = ir.ComputedBuffer(
                    name=node.get_name(), layout=node.get_layout(), data=new_loops
                )
            else:
                new_node = new_loops  # type: ignore[assignment]

            return new_node

        return [wrap_inner_fn_for_node(node) for node in nodes]
446
+
447
+
448
def unify_mask_base_type(
    buffer: IndentedBuffer,
    vars: tuple[CSEVariable, ...],
    dtype=torch.float,
):
    """
    Cast each CSE variable in ``vars`` to the given mask base dtype.

    Returns a generator yielding the casted CSE variables in order; the cast
    expressions are emitted through the kernel's CSE into ``buffer``.
    """
    return (
        V.kernel.cse.generate(buffer, f"{V.kernel._get_mask_cast(var, dtype)}")
        for var in vars
    )
465
+
466
+
467
def may_unify_binary_op_mask_type(a, b):
    """
    Given two cse variables, when dtype is bool, unify them to the same mask
    dtype (int32) and return the casted cse variables; otherwise return the
    operands unchanged.
    """
    if a.dtype != torch.bool:
        return a, b
    # Boolean masks must both be boolean before widening to the int mask type.
    assert b.dtype == torch.bool
    return unify_mask_base_type(V.kernel.compute, (a, b), torch.int32)
476
+
477
+
478
def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
    """
    Append to `code` a C++ immediately-invoked lambda that produces a vector
    of random values.

    `offset` is a vectorized integer CSE variable holding per-lane RNG
    offsets; `rand_function` is a single C++ statement (string) expected to
    fill `result[offset_idx]` from `offset[offset_idx]`.  The lambda spills
    the offsets to a scalar array, runs `rand_function` per lane, then
    reloads `result` as a Vectorized/VectorizedN of `dst_dtype`.
    Returns the same `code` buffer for chaining.
    """
    assert is_integer_dtype(offset.dtype)
    code.writeline("[&]()")
    with code.indent():
        # Scalar staging arrays sized by the kernel's vector tiling factor.
        code.writeline(
            f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];"
        )
        code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];")
        # Spill the vectorized offsets into the scalar array.
        code.writeline(f"{offset}.store(offset);")
        code.writeline(
            f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )"
        )
        with code.indent():
            code.writeline(rand_function)
        # One Vectorized covers the tile when it fits; otherwise use
        # VectorizedN spanning several hardware vectors.
        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
        if num_vectors == 1:
            code.writeline(
                f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);"
            )
        else:
            code.writeline(
                f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);"
            )
    # Immediately invoke the lambda.
    code.writeline("()")
    return code
503
+
504
+
505
def get_gemm_template_output_and_compute_dtype(input_dtype):
    """
    Return ``(output_dtype, compute_dtype)`` for the GEMM template.

    Integer (u8/s8) inputs accumulate in int32; every other input dtype
    computes and outputs in float32.
    """
    if input_dtype in (torch.uint8, torch.int8):
        return (torch.int32, torch.int32)
    return (torch.float32, torch.float32)
510
+
511
+
512
def create_epilogue_with_attr(input_buffer, attr, **kwargs):
    """
    Build a Pointwise IR node applying the fused epilogue `attr` on top of
    `input_buffer`.

    Supported attrs: relu, gelu (algorithm="none"|"tanh"), swish, sigmoid,
    tanh, hardswish, hardsigmoid, leaky_relu, hardtanh, add/sub/mul (with
    `other`), and bias_add (with `other`, `beta`, `dtype`).  Extra operands
    and scalars come in via **kwargs.  Raises ValueError for unknown attrs.
    """
    input_loader = input_buffer.make_loader()
    dtype = input_buffer.get_dtype()
    if attr == "relu":

        def inner_fn(index):
            input = input_loader(index)
            zero = ops.constant(0, dtype)
            return ops.maximum(input, zero)

    elif attr == "gelu":
        assert "algorithm" in kwargs
        if kwargs["algorithm"] == "none":
            # Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2))); computed in fp32
            # and cast back to the input dtype when needed.

            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                const = ops.constant(0.7071067811865476, torch.float)
                result = input * half * (ops.erf(input * const) + one)
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

        else:
            assert kwargs["algorithm"] == "tanh"
            # Tanh approximation:
            # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                const1 = ops.constant(0.7978845608028654, torch.float)
                const2 = ops.constant(0.044715, torch.float)
                result = (
                    half
                    * input
                    * (
                        one
                        + ops.tanh(const1 * (input + const2 * input * input * input))
                    )
                )
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

    elif attr == "swish":

        def inner_fn(index):
            input = input_loader(index)
            result = input * ops.sigmoid(input)
            return result

    elif attr == "sigmoid":

        def inner_fn(index):
            return ops.sigmoid(input_loader(index))

    elif attr == "tanh":

        def inner_fn(index):
            return ops.tanh(input_loader(index))

    elif attr == "hardswish" or attr == "hardsigmoid":

        def hardsigmoid_float(input):
            # clamp(x + 3, 0, 6) / 6, computed in fp32.
            zero = ops.constant(0, torch.float)
            six = ops.constant(6, torch.float)
            three = ops.constant(3, torch.float)
            one_over_six = ops.constant(0.16666666666666666, torch.float)
            max = ops.maximum(input + three, zero)
            min = ops.minimum(max, six)
            return min * one_over_six

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = hardsigmoid_float(input)
            # hardswish = x * hardsigmoid(x); hardsigmoid keeps result as-is.
            if attr == "hardswish":
                result = input * result
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "leaky_relu":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 1
        negative_slope = kwargs["scalars"][0]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            zero = ops.constant(0, torch.float)
            result = ops.where(
                input > zero, input, input * ops.constant(negative_slope, torch.float)
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "hardtanh":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 2
        min_value = kwargs["scalars"][0]
        max_value = kwargs["scalars"][1]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = ops.minimum(
                ops.maximum(input, ops.constant(min_value, torch.float)),
                ops.constant(max_value, torch.float),
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr in ["add", "sub", "mul"]:
        assert "other" in kwargs
        other = kwargs["other"]
        # `other` may have fewer dims than the input; drop the leading
        # `dims_diff` index entries when loading it (trailing-dim broadcast).
        num_input_dims = len(input_buffer.get_size())
        num_other_dims = len(other.get_size())
        dims_diff = num_input_dims - num_other_dims
        other_loader = other.make_loader()

        def inner_fn(index):
            op = getattr(ops, attr)
            if dims_diff != 0:
                return op(input_loader(index), other_loader(index[dims_diff:]))
            else:
                return op(input_loader(index), other_loader(index))

    elif attr == "bias_add":
        assert "other" in kwargs
        assert "beta" in kwargs
        assert "dtype" in kwargs
        beta = kwargs["beta"]
        other = kwargs["other"]
        # NOTE: this rebinds `dtype`, so the resulting Pointwise node uses
        # the caller-provided dtype rather than the input buffer's.
        dtype = kwargs["dtype"]
        bias_loader = other.make_loader()

        def inner_fn(index):
            bias = bias_loader(index)
            input = input_loader(index)
            # beta == 1 skips the scaling multiply.
            if beta != 1:
                result = ops.constant(beta, torch.float) * bias + input
            else:
                result = bias + input
            return result

    else:
        raise ValueError(f"Unsupported epilogue attribute: {attr}")
    return ir.Pointwise(
        device=input_buffer.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input_buffer.get_size(),
    )
676
+
677
+
678
def _get_loop_body(fn_list):
    """
    Extract the `LoopBody` behind each entry of ``fn_list``.

    Entries are either `LoopBody` instances, wrappers produced by
    `LocalBufferContext.localize_function` (carrying the wrapped callable as
    ``original_fn``), or `functools.partial` objects whose first positional
    argument owns ``_body``.
    """
    if all(isinstance(fn, LoopBody) for fn in fn_list):
        loop_bodies = fn_list
    elif hasattr(fn_list[0], "original_fn"):
        # For the case of local buffer, we wrap the fn with localize_function
        assert all(hasattr(fn, "original_fn") for fn in fn_list)
        assert all(
            isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
        )
        loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
    else:
        assert all(isinstance(fn, functools.partial) for fn in fn_list)
        assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
        loop_bodies = [fn.args[0]._body for fn in fn_list]
    assert loop_bodies is not None
    return loop_bodies
695
+
696
+
697
def _get_dtype_from_loopbodies(loop_bodies):
    """Collect the dtypes recorded on every `call_method` FX node across the
    root graph and all subblock graphs of the given loop bodies."""
    dtypes = OrderedSet[torch.dtype]()
    for loop_body in loop_bodies:
        sub_graphs = [block.graph for block in loop_body.subblocks.values()]
        for graph in [loop_body.root_block.graph, *sub_graphs]:
            for node in graph.nodes:
                if node.op == "call_method":
                    dtypes.add(node.meta[OptimizationContext.key].dtype)
    return dtypes
709
+
710
+
711
def template_fusion_with_epilogues_supported(
    template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode]
) -> tuple[bool, bool]:
    """
    Decide whether `template` can be fused with the given epilogue nodes.

    Returns ``(supported, same_index)``: fusion is supported only when every
    epilogue reads the template's output buffer(s) through at most one index
    expression, and that read index matches the epilogue's write indexes.
    """

    def _get_indexes_of_template_buf_read(
        epilogue_node: ir.Operation, template_buf_names: list[str]
    ) -> list[sympy.Expr]:
        # All index expressions this epilogue uses to read template outputs.
        return [
            read.index
            for read in epilogue_node.get_reads()
            if read.name in template_buf_names
        ]

    def _check_supported_and_same_indexes(
        index_of_template_buf_read: Sequence[sympy.Expr],
        epilogue_writes: OrderedSet[Dep],
    ) -> tuple[bool, bool]:
        num_indexes = len(OrderedSet(index_of_template_buf_read))

        if num_indexes > 1:
            same_index = False
            supported = False  # Different read indexes not supported
        elif num_indexes == 0:
            same_index = True
            supported = True  # No reads, automatically supported
        elif num_indexes == 1:
            iotbr = index_of_template_buf_read[0]
            same_index = all(write.index == iotbr for write in epilogue_writes)
            # TODO: Add support of fusion when the read of template buffer and the write of epilogue output
            # in the epilogue node don't have the same index and change supported to True
            supported = same_index
        else:
            raise AssertionError("Should not reach here")

        return supported, same_index

    def _template_fusion_supported(
        template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation]
    ) -> tuple[bool, bool]:
        template_buf_names = [x.get_name() for x in template_outputs]
        indexes_of_template_buf_reads = [
            _get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
            for epilogue_node in epilogue_nodes
        ]
        epilogue_nodes_writes = [
            epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
        ]

        # Fusion requires every epilogue to individually pass the check.
        results = [
            _check_supported_and_same_indexes(reads, writes)
            for reads, writes in zip(
                indexes_of_template_buf_reads, epilogue_nodes_writes
            )
        ]
        supported, same_indexes = zip(*results)
        return all(supported), all(same_indexes)

    assert template.is_template()
    template_outputs = template.get_outputs()

    # Flatten scheduler nodes to their underlying IR operations.
    epilogue_nodes = [
        n.node
        for epilogue in epilogues
        for n in epilogue.get_nodes()
        if n.node is not None
    ]
    return _template_fusion_supported(template_outputs, epilogue_nodes)
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py ADDED
@@ -0,0 +1,878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections.abc import Sequence
3
+ from typing import Any, Callable, Optional, Union
4
+
5
+ import sympy
6
+
7
+ import torch
8
+ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
9
+ import torch._ops
10
+
11
+ from .. import config, ir
12
+ from ..utils import sympy_product
13
+ from ..virtualized import V
14
+ from .cpp_utils import DTYPE_TO_CPP
15
+ from .cpp_wrapper_cpu import CppWrapperCpu
16
+ from .wrapper import (
17
+ BufferLike,
18
+ EnterSubgraphLine,
19
+ ExitSubgraphLine,
20
+ MemoryPlanningLine,
21
+ MemoryPlanningState,
22
+ PythonWrapperCodegen,
23
+ )
24
+
25
+
26
+ BufferName = str
27
+
28
+ # Default thread stack sizes vary by platform:
29
+ # - Linux: 8 MB
30
+ # - macOS: 512 KB
31
+ # - Windows: 1 MB
32
+ # Just pick something comfortably smaller than the smallest for now.
33
+ MAX_STACK_ALLOCATION_SIZE = 1024 * 100
34
+
35
+
36
+ class CppWrapperCpuArrayRef(CppWrapperCpu):
37
+ """
38
+ Generates cpp wrapper for running on CPU and calls cpp kernels
39
+
40
+ This class is forked from CppWrapperCpu, with a difference that tensors may be
41
+ represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
42
+ """
43
+
44
    def __init__(self):
        super().__init__()
        # ArrayRef tensors are stack/thread-local storage, which only makes
        # sense on CPU (self.device is set by the base-class constructor).
        assert self.device == "cpu", "ArrayRefTensor only supported on CPU!"
        # May be flipped to False later (e.g. by extern-kernel codegen).
        self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation
        # Buffers that were promoted to stack allocations, keyed by name.
        self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
49
+
50
    @staticmethod
    def create(
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
    ):
        """Factory matching the wrapper-codegen creation interface.

        The subgraph-related arguments are currently ignored: subgraph
        codegen is not yet supported for the array-ref wrapper.
        """
        # TODO - support subgraph codegen by lifting functions. Check the
        # comment at CppWrapperCpu `codegen_subgraph` function.
        return CppWrapperCpuArrayRef()
60
+
61
+ @staticmethod
62
+ def get_input_cpp_type(input):
63
+ assert config.aot_inductor.use_minimal_arrayref_interface
64
+
65
+ if isinstance(input, sympy.Expr):
66
+ from ..graph import may_get_constant_buffer_dtype
67
+
68
+ dtype = may_get_constant_buffer_dtype(input)
69
+ assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
70
+ return DTYPE_TO_CPP[dtype]
71
+ return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
72
+
73
+ @staticmethod
74
+ def get_device_include_path(device: str) -> str:
75
+ assert device == "cpu", "ArrayRef only supported on CPU!"
76
+ if V.graph.aot_mode:
77
+ return "#include <torch/csrc/inductor/aoti_include/array_ref.h>"
78
+ return "#include <torch/csrc/inductor/cpp_wrapper/array_ref.h>"
79
+
80
+ def codegen_input_numel_asserts(self):
81
+ for name, buf in V.graph.graph_inputs.items():
82
+ if isinstance(buf, sympy.Expr):
83
+ continue
84
+
85
+ # comparing strides for 0 size tensor is tricky. Ignore them for now.
86
+ if sympy_product(buf.get_size()) == 0:
87
+ continue
88
+ numel = buf.get_numel()
89
+ self.prefix.writeline(f"assert_numel({name}, {numel});")
90
+
91
    def generate_extern_kernel_alloc(self, *args, **kwargs):
        # Disable stack allocation for extern kernels: they require real
        # AtenTensorHandles rather than stack-backed ArrayRef tensors.
        self.allow_stack_allocation = False
        super().generate_extern_kernel_alloc(*args, **kwargs)
95
+
96
    def generate_extern_kernel_out(self, *args, **kwargs):
        # Disable stack allocation for extern kernels (see
        # generate_extern_kernel_alloc), then defer to the base class.
        self.allow_stack_allocation = False
        super().generate_extern_kernel_out(*args, **kwargs)
100
+
101
    def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
        # Disable stack allocation for extern/fallback kernels, then defer
        # to the base class implementation.
        self.allow_stack_allocation = False
        super().generate_fallback_kernel(node)
105
+
106
    def _generate_kernel_call_helper(
        self,
        kernel_name: str,
        call_args,
        *,
        device=None,
        triton=True,
        arg_types=None,
        raw_keys=None,
        raw_args=None,
        triton_meta=None,
        graph_name="",
        original_fxnode_name=None,
    ):
        """
        Generates kernel call code.

        triton: Defines whether the GPU backend uses Triton for codegen.
                Otherwise it uses the CUDA language for codegen.
                Only valid when cuda == True.

        This CPU/ArrayRef variant rejects Triton entirely; pointer-typed
        arguments are unwrapped via get_data_ptr_wrapper and cast to the
        declared C type before the call is emitted.
        """
        assert not triton, (
            "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
        )
        assert arg_types is not None and len(call_args) == len(arg_types), (
            "Mismatch call_args and arg_types in generate_kernel_call"
        )
        new_args = []
        for idx, arg in enumerate(call_args):
            if "*" in arg_types[idx]:
                # Pointer argument: materialize a typed raw pointer first so
                # both RAII handles and ArrayRef tensors can be passed.
                var_name = f"var_{next(self.arg_var_id)}"
                self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});")
                new_args.append(f"({arg_types[idx]})({var_name})")
            else:
                # arg is a scalar
                new_args.append(arg)
        # debug printer related logic for cpp kernel type.
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args,
            kernel_name,
            None,
            None,
            "cpp",
        )
        with debug_printer_manager:
            self.writeline(self.wrap_kernel_call(kernel_name, new_args))
153
+
154
    def write_wrapper_decl(self):
        """Emit the C++ entry-point declaration(s) and the input/constant
        unpacking prologue for either AOT mode (run_impl / minimal arrayref
        interface / const-graph) or JIT cpp-wrapper mode."""
        inputs_len = len(V.graph.graph_inputs.keys())
        if V.graph.aot_mode:
            if (
                config.aot_inductor.use_minimal_arrayref_interface
                and not V.graph.is_const_graph
            ):
                # Tuple type aliases describing the typed input/output
                # signature of the minimal arrayref interface.
                input_cpp_types = ", ".join(
                    f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}"
                    for x in V.graph.graph_inputs.values()
                )
                output_arrayref_types = ", ".join(
                    f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
                    for x in V.graph.graph_outputs
                )

                self.prefix.splice(
                    f"""
                    using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
                    using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
                    """
                )

            if V.graph.const_module:
                # Pull in the separately-compiled constant graph's code.
                self.header.splice(V.graph.const_module.wrapper_code.header)

                assert V.graph.const_wrapper_code is not None
                self.prefix.splice(V.graph.const_wrapper_code)

                assert V.graph.const_kernel_code is not None
                self.kernel_declarations.splice(V.graph.const_kernel_code)

            if V.graph.is_const_graph:
                self.prefix.splice(
                    """
                    void AOTInductorModel::_const_run_impl(
                        std::vector<AtenTensorHandle>& output_handles,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """
                )
            else:
                if not config.aot_inductor.use_runtime_constant_folding:
                    # If we do not split the constant graph, we'll just create
                    # an empty implementation when wrapping the main module.
                    self.prefix.splice(
                        """
                        void AOTInductorModel::_const_run_impl(
                            std::vector<AtenTensorHandle>& output_handles,
                            DeviceStreamType stream,
                            AOTIProxyExecutorHandle proxy_executor
                        ) {}

                        """
                    )

                run_impl_proto = """
                    void AOTInductorModel::run_impl(
                        AtenTensorHandle*
                            input_handles, // array of input AtenTensorHandle; handles
                                           // are stolen; the array itself is borrowed
                        AtenTensorHandle*
                            output_handles, // array for writing output AtenTensorHandle; handles
                                            // will be stolen by the caller; the array itself is
                                            // borrowed
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """

                self.generate_input_output_runtime_checks()
                run_impl_proto += """
                    __check_inputs_outputs(input_handles, output_handles);
                    """

                if config.aot_inductor.use_minimal_arrayref_interface:
                    # The real body goes into the templated minimal-arrayref
                    # entry; the handle-based run_impl becomes a thin
                    # conversion shim appended to the suffix.
                    self.prefix.splice(
                        """
                        template <>
                        AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
                          AOTInductorModelInputs, AOTInductorModelOutputs>(
                            const AOTInductorModelInputs& inputs,
                            DeviceStreamType stream,
                            AOTIProxyExecutorHandle proxy_executor
                        ) {
                        """
                    )
                    self.suffix.splice(run_impl_proto)
                    self.suffix.splice(
                        """
                        AOTInductorModelInputs inputs;
                        convert_handles_to_inputs(input_handles, inputs);
                        auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                            inputs, stream, proxy_executor);
                        // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
                        // interface to perform well for a DSO using the minimal arrayref interface, all we need
                        // to do is provide ThreadLocalCachedTensor for each one!
                        convert_outputs_to_handles(outputs, output_handles);
                        }
                        """
                    )

                    self.suffix.splice(
                        """
                        extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
                            AOTInductorModelHandle model_handle,
                            const AOTInductorModelInputs& inputs,
                            AOTInductorModelOutputs& outputs) {
                          auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
                          CONVERT_EXCEPTION_TO_ERROR_CODE({
                              outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                                  inputs,
                                  (torch::aot_inductor::DeviceStreamType)nullptr,
                                  nullptr);
                          })
                        }
                        """
                    )
                else:
                    self.prefix.splice(run_impl_proto)
        else:
            # cpp entry function for JIT with cpp wrapper
            self.prefix.splice(
                """
                void inductor_entry_impl(
                    AtenTensorHandle*
                        input_handles, // array of input AtenTensorHandle; handles
                                       // are stolen; the array itself is borrowed
                    AtenTensorHandle*
                        output_handles // array for writing output AtenTensorHandle; handles
                                       // will be stolen by the caller; the array itself is
                                       // borrowed)
                ) {
                """
            )
        with self.prefix.indent():
            # assign inputs and outputs in both cases so the later codegen can be simplified
            if not config.aot_inductor.use_minimal_arrayref_interface:
                if not V.graph.is_const_graph:
                    if V.graph.aot_mode:
                        num_args = len(V.graph.graph_inputs)
                    else:
                        # Weights are promoted in the JIT mode
                        num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
                        # release GIL to support multiple instances inference (in different threads of the same process)
                        self.prefix.splice("py::gil_scoped_release release;")

                    self.prefix.splice(
                        f"""
                        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
                        """
                    )

            if inputs_len != 0:
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    if config.aot_inductor.use_minimal_arrayref_interface:
                        # Inputs arrive as a typed std::tuple here.
                        self.prefix.writeline(
                            f"auto {input_key} = std::get<{idx}>(inputs);"
                        )
                        continue
                    # unwrap input tensor back to scalar
                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                        from ..graph import may_get_constant_buffer_dtype

                        dtype = may_get_constant_buffer_dtype(
                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                        )
                        assert dtype is not None, (
                            "Fails to get the dtype of the sympy.Expr"
                        )
                        self.codegen_tensor_item(
                            dtype, f"inputs[{idx}]", input_key, self.prefix
                        )
                    else:
                        self.prefix.writeline(
                            f"auto {input_key} = std::move(inputs[{idx}]);"
                        )

            assert all(
                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
            ), "Expect all constants to be Tensor"
            for idx, constants_key in enumerate(V.graph.constants.keys()):
                if V.graph.aot_mode:
                    # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
                    # Don't call std::move here because it will cause constants_ to lose the ownership.
                    self.prefix.writeline(
                        f"""auto {constants_key} = constants_->at({idx});"""
                    )
                else:
                    # Append constants as inputs to the graph
                    constants_idx = inputs_len + idx
                    self.prefix.writeline(
                        f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
                    )

            self.codegen_inputs()

            if V.graph.aot_mode:
                if not V.graph.is_const_graph:
                    if config.aot_inductor.use_minimal_arrayref_interface:
                        # TODO: input shape checking for regular tensor interface as well?
                        self.codegen_input_numel_asserts()
                    else:
                        self.prefix.writeline("inputs.clear();")
                    self.prefix.writeline(
                        "[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
                    )
363
    def generate_return(self, output_refs: list[str]):
        """Emit the code that hands each graph output back to the caller,
        either as entries of the minimal-arrayref output tuple or by filling
        ``output_handles``; constants are cloned rather than released."""
        cst_names = V.graph.constants.keys()
        arr_iface = (
            not V.graph.is_const_graph
            and config.aot_inductor.use_minimal_arrayref_interface
        )  # For brevity.

        def use_thread_local_cached_output_tensor(idx, output):
            # Fallback path for outputs that are not tensor handles: copy
            # them into a thread_local cache object and expose that.
            cached_output_name = f"cached_output_{next(self.cached_output_id)}"
            cache_type = "Array" if arr_iface else "Tensor"
            self.wrapper_call.writeline(
                f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
                f"{cached_output_name}({output});"
            )
            if arr_iface:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
                element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
                self.wrapper_call.writeline(
                    f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
                )
            else:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                    f"output_handles[{idx}]));"
                )

        if arr_iface:
            self.wrapper_call.writeline(
                "AOTInductorModelOutputs output_arrayref_tensors;"
            )

        # Tracks the first output slot each distinct ref was released into,
        # so duplicated outputs alias instead of double-releasing.
        output2idx: dict[str, int] = {}
        for idx, output in enumerate(output_refs):
            if output == "nullptr":
                continue

            is_constant_buffer = output in cst_names
            output_buffer = V.graph.graph_outputs[idx]
            if isinstance(output_buffer, ir.BaseView):
                # A view over a constant is still a constant for ownership
                # purposes.
                output_storage = output_buffer.unwrap_view()
                if isinstance(output_storage.data, ir.ConstantBuffer):
                    is_constant_buffer = True

            if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
                # Need to wrap scalar into tensor as the main function returns a vector of tensors
                output_tensor = self.codegen_scalar_to_tensor(output)
                self.wrapper_call.writeline(
                    f"output_handles[{idx}] = {output_tensor}.release();"
                )
                continue

            # Compile-time check: does this output's C++ type hold an ATen
            # tensor handle (vs. an ArrayRef-style value)?
            output_is_tensor_handle_expr = (
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "RAIIAtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "AtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "ConstantHandle>"
            )
            self.wrapper_call.writeline(
                f"if constexpr ({output_is_tensor_handle_expr}) {{"
            )
            with self.wrapper_call.indent():
                if arr_iface:
                    cached_output_name = f"cached_output_{next(self.cached_output_id)}"
                    self.wrapper_call.writeline(
                        f"thread_local RAIIAtenTensorHandle {cached_output_name};"
                    )
                    if is_constant_buffer:
                        # NOTE(return_constant): In some rare cases where we return
                        # a constant, we have to return a copy of this constant,
                        # because (1) constants are not owned by the Model instance
                        # (2) constants remain the same cross inference runs,
                        # assuming they are not updated at runtime Basically, we
                        # cannot release or transfer the ownership of any original
                        # constant to the user.
                        self.wrapper_call.writeline(
                            f"AtenTensorHandle {cached_output_name}_tmp;"
                        )
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
                        )
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {cached_output_name}_tmp;"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {output}.release();"
                        )
                    self.wrapper_call.writeline(
                        f"convert_handle_to_arrayref_tensor({cached_output_name}, "
                        f"std::get<{idx}>(output_arrayref_tensors));"
                    )
                else:
                    if is_constant_buffer:
                        # See NOTE(return_constant) above.
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                        )
                    else:
                        if output in output2idx:
                            # Duplicate output: alias the already-released
                            # handle instead of releasing twice.
                            src_idx = output2idx[output]
                            self.wrapper_call.writeline(
                                f"output_handles[{idx}] = output_handles[{src_idx}];"
                            )
                        else:
                            self.wrapper_call.writeline(
                                f"output_handles[{idx}] = {output}.release();"
                            )
            self.wrapper_call.writeline("} else {")
            with self.wrapper_call.indent():
                use_thread_local_cached_output_tensor(idx, output)
            self.wrapper_call.writeline("}")

            if output not in output2idx:
                output2idx[output] = idx
        if arr_iface:
            self.wrapper_call.writeline("return output_arrayref_tensors;")
491
+ def memory_plan(self):
492
+ from .memory_planning import MemoryPlanner
493
+
494
+ self.lines = MemoryPlanner(self).plan(self.lines)
495
+ # TODO: integrate memory planning & stack allocation?
496
+ self.allow_stack_allocation = False
497
+
498
+ def memory_plan_reuse(self):
499
+ out_names = V.graph.get_output_names()
500
+
501
+ while (
502
+ self.lines
503
+ and isinstance(self.lines[-1], MemoryPlanningLine)
504
+ # TODO: this seems legit, NullLine has no node
505
+ and self.lines[-1].node.name not in out_names # type: ignore[attr-defined]
506
+ ):
507
+ # these lines will be pointless
508
+ self.lines.pop()
509
+
510
+ # codegen allocations in two passes
511
+ planning_states = [MemoryPlanningState()]
512
+ past_planning_states = []
513
+ for i in range(len(self.lines)):
514
+ line = self.lines[i]
515
+ if isinstance(line, MemoryPlanningLine):
516
+ self.lines[i] = line.plan(planning_states[-1])
517
+ elif isinstance(line, EnterSubgraphLine):
518
+ planning_states.append(MemoryPlanningState())
519
+ elif isinstance(line, ExitSubgraphLine):
520
+ past_planning_states.append(planning_states.pop())
521
+ past_planning_states.append(planning_states.pop())
522
+ assert len(planning_states) == 0
523
+
524
+ # conservatively use the sum of all allocated buffer sizes
525
+ # in potentially nested scopes as the total allocated size
526
+ total_allocated_buffer_size = sum(
527
+ s.total_allocated_buffer_size for s in past_planning_states
528
+ )
529
+
530
+ self.allow_stack_allocation = (
531
+ self.allow_stack_allocation is not False
532
+ and config.aot_inductor.allow_stack_allocation
533
+ and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
534
+ )
535
+
536
+ def can_stack_allocate_buffer(self, buffer):
537
+ return (
538
+ self.allow_stack_allocation
539
+ and buffer.get_device().type == "cpu"
540
+ and self.can_prove_buffer_has_static_shape(buffer)
541
+ and ir.is_contiguous_strides_for_shape(
542
+ buffer.get_stride(), buffer.get_size()
543
+ )
544
+ )
545
+
546
+ def make_buffer_free(self, buffer):
547
+ return (
548
+ ""
549
+ if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout)
550
+ or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
551
+ or (
552
+ config.aot_inductor.use_minimal_arrayref_interface
553
+ and V.graph.aot_mode
554
+ and buffer.get_name() in V.graph.graph_inputs
555
+ )
556
+ else f"{buffer.get_name()}.reset();"
557
+ )
558
+
559
+ def make_buffer_allocation(self, buffer):
560
+ return self.make_allocation(
561
+ buffer.get_name(),
562
+ buffer.get_device(),
563
+ buffer.get_dtype(),
564
+ buffer.get_size(),
565
+ buffer.get_stride(),
566
+ buffer if self.can_stack_allocate_buffer(buffer) else None,
567
+ )
568
+
569
+ def make_allocation(
570
+ self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
571
+ ):
572
+ orig_stride = stride
573
+ device_str = self.codegen_device(device)
574
+ dtype_code = self.codegen_dtype(dtype)
575
+ size = self.codegen_shape_tuple(shape)
576
+ stride = self.codegen_shape_tuple(orig_stride)
577
+ size_array_var = self.codegen_int_array_var(
578
+ size,
579
+ self.wrapper_call.writeline,
580
+ known_statically=self.is_statically_known_list_of_ints(shape),
581
+ graph=self.get_codegened_graph(),
582
+ )
583
+ stride_array_var = self.codegen_int_array_var(
584
+ stride,
585
+ self.wrapper_call.writeline,
586
+ known_statically=self.is_statically_known_list_of_ints(orig_stride),
587
+ graph=self.get_codegened_graph(),
588
+ )
589
+ device_type, device_id = device_str.split(",")
590
+ device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
591
+ if buffer_if_can_stack_allocate is not None:
592
+ self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
593
+ cpp_type = DTYPE_TO_CPP[dtype]
594
+ numel = buffer_if_can_stack_allocate.get_numel()
595
+ # Note: we don't zero storage because empty_strided doesn't zero either.
596
+ self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
597
+ args = [
598
+ f"{name}_storage",
599
+ size_array_var,
600
+ stride_array_var,
601
+ device_type,
602
+ device_idx,
603
+ ]
604
+ return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"
605
+
606
+ args = [
607
+ str(len(shape)),
608
+ size_array_var,
609
+ stride_array_var,
610
+ dtype_code,
611
+ device_type,
612
+ device_idx,
613
+ f"&{name}_handle",
614
+ ]
615
+
616
+ self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
617
+ self.wrapper_call.writeline(
618
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
619
+ )
620
+
621
+ return f"RAIIAtenTensorHandle {name}({name}_handle);"
622
+
623
+ def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
624
+ assert old.get_dtype() == new.get_dtype()
625
+ old_name = old.get_name()
626
+ new_name = new.get_name()
627
+ del_line = ";"
628
+ if old_name not in V.graph.get_output_names() and delete_old:
629
+ del_line = f"; {self.make_buffer_free(old)}"
630
+
631
+ if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
632
+ if old_name in self.stack_allocated_buffers:
633
+ self.stack_allocated_buffers[new_name] = new
634
+ return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
635
+
636
+ reinterpret_view = self.codegen_reinterpret_view(
637
+ old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline
638
+ )
639
+ if reinterpret_view in self.stack_allocated_buffers:
640
+ self.stack_allocated_buffers[new_name] = new
641
+ # The only way to get into this case is via an exact buffer reuse, since all
642
+ # other options result in a new tensor handle.
643
+ return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
644
+ return f"{self.declare}{new_name} = {reinterpret_view}{del_line} // reuse"
645
+
646
+ def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
647
+ # Borrowing arguments to shim functions is only safe because we know
648
+ # that the arguments can't be stack-allocated. Otherwise, to be sure
649
+ # we can't return a dangling pointer, we need to either 1) be
650
+ # certain that the shim function cannot return an alias of a
651
+ # borrowed argument, or 2) be certain that the returned Tensor from
652
+ # the shim function cannot escape.
653
+ assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), (
654
+ "borrowing arguments to shim functions is unsafe with "
655
+ "stack allocation on! (see comment above this assertion)"
656
+ )
657
+
658
+ def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
659
+ return not self.allow_stack_allocation and not self.stack_allocated_buffers
660
+
661
+ def generate_c_shim_extern_kernel_call(
662
+ self, kernel: str, args: list[str], device: str, **_
663
+ ) -> None:
664
+ # In the abi_compatible mode, we call fallback aten ops through a C shim layer
665
+ # Setting self.allow_stack_allocation to False because the exchange between
666
+ # ArrayRefTensor and at::Tensor is still fragile.
667
+ self.allow_stack_allocation = False
668
+
669
+ wrapped_args = []
670
+ for arg in args:
671
+ # We only really *need* borrow_arrayref_tensor_as_tensor for
672
+ # ArrayRefTensors. The code flowing into here uses `0` for nullptr, which
673
+ # borrow_arrayref_tensor_as_tensor would blindly coerce to int, so just
674
+ # avoid wrapping integers. Name matching is to find tensor is hacky, but
675
+ # fixing all the ArrayRefTensor issues is not a priority for now.
676
+ if isinstance(arg, str) and arg.startswith(
677
+ ("buf", "arg", "wrap_with_raii_handle_if_needed")
678
+ ):
679
+ self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
680
+ arg = f"borrow_arrayref_tensor_as_tensor({arg})"
681
+ wrapped_args.append(arg)
682
+
683
+ super().generate_c_shim_extern_kernel_call(
684
+ kernel, wrapped_args, device, debug_args=args
685
+ )
686
+
687
+ def generate_scatter_fallback(
688
+ self,
689
+ output,
690
+ inputs,
691
+ cpp_kernel_name,
692
+ python_kernel_name,
693
+ src_is_tensor,
694
+ reduce,
695
+ kwargs,
696
+ ):
697
+ # No stack allocation when there is a fallback op
698
+ self.allow_stack_allocation = False
699
+
700
+ # call the ABI shim function instead of the ATen one
701
+ cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
702
+ # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
703
+ cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
704
+ self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
705
+ inputs_wrapped = [
706
+ (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x))
707
+ for x in inputs
708
+ ]
709
+ line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}"
710
+
711
+ if python_kernel_name.startswith("aten.scatter_reduce"):
712
+ line += f", {','.join(kwargs)}"
713
+ else:
714
+ if src_is_tensor:
715
+ if reduce:
716
+ line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
717
+ else:
718
+ assert reduce is None, (
719
+ "Expect reduce to be None for aten.scatter_ with scalar src"
720
+ )
721
+ line += ");"
722
+ self.writeline(line)
723
+
724
+ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
725
+ # No stack allocation when there is a fallback op
726
+ self.allow_stack_allocation = False
727
+
728
+ self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
729
+ # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
730
+ # See the comment in codegen_reinterpret_view about why having something like
731
+ # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
732
+ # tensor prematurely deallocated, thus the temporary array trick here.
733
+ indices_str = self._generate_temporary_array_pointer(
734
+ "AtenTensorHandle",
735
+ [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices],
736
+ )
737
+ args = [
738
+ f"borrow_arrayref_tensor_as_tensor({x})",
739
+ indices_str,
740
+ str(len(indices)),
741
+ f"borrow_arrayref_tensor_as_tensor({values})",
742
+ accumulate,
743
+ ]
744
+ args.insert(
745
+ 0, f"borrow_arrayref_tensor_as_tensor({x})"
746
+ ) # set x as the output tensor, this fallback mutates x.
747
+ self.writeline(self.wrap_kernel_call(kernel, args))
748
+
749
+ def generate_fallback_kernel_with_runtime_lookup(
750
+ self,
751
+ buf_name: str,
752
+ python_kernel_name: str,
753
+ get_args: Callable[[], Sequence[str]],
754
+ op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
755
+ raw_args: Sequence[Any],
756
+ outputs: Sequence[ir.Buffer],
757
+ ) -> None:
758
+ # No stack allocation when there is a fallback op
759
+ self.allow_stack_allocation = False
760
+ super().generate_fallback_kernel_with_runtime_lookup(
761
+ buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs
762
+ )
763
+
764
+ def codegen_device_copy(self, src, dst, non_blocking: bool):
765
+ # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
766
+ # while stack-allocation results in ArrayRefTensor
767
+ # so disable stack allocation here
768
+ self.allow_stack_allocation = False
769
+ self.writeline(
770
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
771
+ )
772
+
773
+ def codegen_reinterpret_view(
774
+ self,
775
+ data,
776
+ size,
777
+ stride,
778
+ offset,
779
+ writeline: Callable[..., None],
780
+ dtype=None,
781
+ ) -> str:
782
+ """Returns a newly-created, temporary RAII tensor handle containing the
783
+ reinterpreted tensor data. Callers of this function are responsible for saving
784
+ the handle if persistent access is needed."""
785
+ dim = str(len(size))
786
+
787
+ def create_reinterpret_call() -> str:
788
+ args = [
789
+ f"{data.get_name()}",
790
+ dim,
791
+ self.codegen_int_array_var(
792
+ self.codegen_shape_tuple(size),
793
+ writeline,
794
+ known_statically=self.is_statically_known_list_of_ints(size),
795
+ graph=self.get_codegened_graph(),
796
+ ),
797
+ self.codegen_int_array_var(
798
+ self.codegen_shape_tuple(stride),
799
+ writeline,
800
+ known_statically=self.is_statically_known_list_of_ints(stride),
801
+ graph=self.get_codegened_graph(),
802
+ ),
803
+ offset,
804
+ ]
805
+ return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))"
806
+
807
+ def create_new_tensor_handle() -> tuple[str, list[str]]:
808
+ # Calling reset() on ArrayRefTensor does nothing, since the array is
809
+ # const-allocated on the stack. Thus, it's safe to return a reference to
810
+ # the original array.
811
+ if (name := data.get_name()) in self.stack_allocated_buffers:
812
+ return name, []
813
+
814
+ tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}"
815
+ tmp_call_strs = [
816
+ f"AtenTensorHandle {tmp_AtenTensorHandle};",
817
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));",
818
+ ]
819
+ return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs
820
+
821
+ if (
822
+ size == data.layout.size
823
+ and stride == data.layout.stride
824
+ and offset == data.layout.offset
825
+ and (dtype is None or dtype == data.dtype)
826
+ ):
827
+ final_tensor_str, call_strs = create_new_tensor_handle()
828
+ for line in call_strs:
829
+ writeline(line)
830
+ return final_tensor_str
831
+
832
+ return super().codegen_reinterpret_view(
833
+ data, size, stride, offset, writeline, dtype
834
+ )
835
+
836
+ def val_to_arg_str(self, val, type_=None) -> str:
837
+ if (
838
+ val is not None
839
+ and isinstance(type_, torch.OptionalType)
840
+ and isinstance(type_.getElementType(), torch.TensorType)
841
+ ):
842
+ # Handle optional tensors as a special case, as in the parent class.
843
+ base_handle = self.val_to_arg_str(val, torch.TensorType)
844
+ if config.aot_inductor.use_minimal_arrayref_interface:
845
+ if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor():
846
+ base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})"
847
+ else:
848
+ base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})"
849
+ return f"&temporary_reference({base_handle}.get())"
850
+
851
+ return super().val_to_arg_str(val, type_)
852
+
853
+ def codegen_tensor_item(
854
+ self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
855
+ ):
856
+ dtype_str = str(dtype).split(".")[-1]
857
+ writer = indented_buffer or self
858
+
859
+ if dtype == torch.float16 or dtype == torch.bfloat16:
860
+ scalar_tmp = f"{scalar}_tmp"
861
+ writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
862
+
863
+ # We know that item_ doesn't alias the input, so borrowing should be safe.
864
+ tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"
865
+
866
+ writer.writeline(
867
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
868
+ )
869
+ writer.writeline(f"float {scalar} = float({scalar_tmp});")
870
+ else:
871
+ writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
872
+
873
+ # We know that item_ doesn't alias the input, so borrowing should be safe.
874
+ tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"
875
+
876
+ writer.writeline(
877
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
878
+ )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import dataclasses
5
+ import re
6
+ from itertools import count, zip_longest
7
+ from typing import Any, Optional, Union
8
+ from typing_extensions import Self
9
+
10
+ import sympy
11
+
12
+ import torch
13
+ from torch import dtype as torch_dtype
14
+ from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
15
+ from torch._inductor.runtime.runtime_utils import dynamo_timed
16
+
17
+ from .. import config
18
+ from ..codecache import CudaKernelParamCache
19
+ from ..ir import (
20
+ GraphPartitionSignature,
21
+ TensorBox,
22
+ TMADescriptorExperimental,
23
+ TMADescriptorStable,
24
+ )
25
+ from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer
26
+ from ..virtualized import V
27
+ from .aoti_hipify_utils import maybe_hipify_code_wrapper
28
+ from .common import get_device_op_overrides, TritonScratchWorkspace
29
+ from .cpp_utils import cexpr
30
+ from .cpp_wrapper_cpu import CppWrapperCpu
31
+ from .multi_kernel import MultiKernelCall
32
+ from .triton_utils import should_unwrap_unspec_arg
33
+ from .wrapper import PythonWrapperCodegen, SymbolicCallArg
34
+
35
+
36
+ _cpp_string_literal_escapes = {
37
+ "\\": "\\\\",
38
+ '"': '\\"',
39
+ "\n": "\\n",
40
+ "\t": "\\t",
41
+ "\r": "\\r",
42
+ }
43
+ _cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]')
44
+
45
+
46
+ def cpp_string_literal(s: str) -> str:
47
+ escaped = _cpp_string_literal_pattern.sub(
48
+ lambda match: _cpp_string_literal_escapes[match.group(0)], s
49
+ )
50
+ return f'"{escaped}"'
51
+
52
+
53
+ @dataclasses.dataclass
54
+ class DeferredTritonCallWrapper:
55
+ """
56
+ When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
57
+ to be tuned and stored as cubin files, so use a deferred generating the final wrapper around
58
+ the triton kernel until right before the prefix is written.
59
+ """
60
+
61
+ wrapper_name: str
62
+ kernel_name: str
63
+ kernel_name_to_body: dict[str, str]
64
+ arg_types: list[Any]
65
+
66
+ def generate(self, wrapper: CppWrapperGpu):
67
+ """
68
+ Generate the GPU kernel definition, as well as load and launch code.
69
+ """
70
+ prefix = wrapper.prefix
71
+ if self.kernel_name.startswith("multi_kernel_"):
72
+ # MultiKernel will select one kernel after running the autotune block
73
+ self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
74
+ params = CudaKernelParamCache.get(self.kernel_name)
75
+ assert params, f"CudaKernelParamCache not populated for {self.kernel_name}"
76
+ def_args = params["def_args"]
77
+ arg_types = self.arg_types
78
+ inductor_meta = params["inductor_meta"]
79
+
80
+ if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types):
81
+ # extra_launcher_args should already be in def_args
82
+ assert len(def_args) == len(arg_types) - len(
83
+ inductor_meta["extra_launcher_args"]
84
+ )
85
+ arg_types = arg_types + [SymbolicCallArg] * len(
86
+ inductor_meta["extra_launcher_args"]
87
+ )
88
+
89
+ if not V.graph.aot_mode:
90
+ prefix.writeline(
91
+ maybe_hipify_code_wrapper(
92
+ f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;"
93
+ )
94
+ )
95
+ kernel_var_name = self.kernel_name
96
+ else:
97
+ kernel_var_name = f"kernels_.{self.kernel_name}"
98
+
99
+ # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types
100
+ template_types = [
101
+ f"typename {name}_type_"
102
+ for name, arg_type in zip(def_args, arg_types)
103
+ if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg))
104
+ ]
105
+ if V.graph.aot_mode:
106
+ template_types.append("typename kernels_type_")
107
+ if template_types:
108
+ prefix.writeline(f"template <{', '.join(template_types)}>")
109
+ prefix.writeline(f"static inline void {self.wrapper_name}(")
110
+ with prefix.indent():
111
+ assert len(def_args) == len(arg_types), (def_args, arg_types)
112
+ for name, arg_type in zip(def_args, arg_types):
113
+ if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)):
114
+ prefix.writeline(f"const {name}_type_& {name},")
115
+ elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)):
116
+ prefix.writeline(f"int64_t {name},")
117
+ elif arg_type is float:
118
+ prefix.writeline(f"float {name},")
119
+ elif arg_type is bool:
120
+ prefix.writeline(f"bool {name},")
121
+ else:
122
+ raise ValueError(f"Unexpected arg type {arg_type}")
123
+ prefix.writeline("int32_t device_idx_,")
124
+ prefix.writeline(
125
+ maybe_hipify_code_wrapper(
126
+ f"{wrapper.device_codegen.cpp_stream_type()} stream_,"
127
+ )
128
+ )
129
+ if V.graph.aot_mode:
130
+ prefix.writeline("kernels_type_& kernels_,")
131
+ prefix.writeline(
132
+ "const std::optional<std::string>& cubin_dir_ = std::nullopt"
133
+ )
134
+ prefix.writeline("){")
135
+ with prefix.indent():
136
+ if V.graph.aot_mode:
137
+ # Emit the original Triton kernel for debugging purposes
138
+ prefix.writeline("/*")
139
+ prefix.splice(self.kernel_name_to_body[self.kernel_name])
140
+ prefix.writeline("*/")
141
+ self.generate_grid(prefix, inductor_meta, params)
142
+ self.generate_load_kernel(prefix, kernel_var_name, params)
143
+ self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
144
+ prefix.writeline("}")
145
+
146
+ if not config.aot_inductor.embed_kernel_binary:
147
+ # Ensure the cubin file is included in the package
148
+ V.graph.wrapper_code.additional_files.append(
149
+ params[get_cpp_wrapper_cubin_path_name()]
150
+ )
151
+
152
+ def generate_grid(
153
+ self,
154
+ prefix: IndentedBuffer,
155
+ inductor_meta: dict[str, Any],
156
+ params: dict[str, Any],
157
+ ):
158
+ from ..runtime.triton_heuristics import GridExpr
159
+
160
+ grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp")
161
+ for line in grid.prefix:
162
+ prefix.writeline(line)
163
+ prefix.splice(
164
+ f"""\
165
+ uint32_t grid_0 = {grid.x_grid};
166
+ uint32_t grid_1 = {grid.y_grid};
167
+ uint32_t grid_2 = {grid.z_grid};
168
+ """
169
+ )
170
+ prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;")
171
+
172
+ def generate_load_kernel(self, prefix, kernel_var_name, params):
173
+ prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
174
+ with prefix.indent():
175
+ embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"]
176
+ if torch.xpu.is_available():
177
+ # XPU needs the end address of the kernel to calculate the size of the kernel binary.
178
+ embed_kernel_args.append(
179
+ f"__{params['inductor_meta']['kernel_name']}_end"
180
+ )
181
+
182
+ load_kernel_args = (
183
+ [
184
+ *embed_kernel_args,
185
+ cpp_string_literal(params["mangled_name"]),
186
+ str(params["shared_mem"]),
187
+ ]
188
+ if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary
189
+ else [
190
+ cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
191
+ cpp_string_literal(params["mangled_name"]),
192
+ str(params["shared_mem"]),
193
+ "cubin_dir_",
194
+ ]
195
+ )
196
+ prefix.writeline(
197
+ f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
198
+ )
199
+ prefix.writeline("}")
200
+
201
+ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params):
202
+ triton_meta = params["triton_meta"]
203
+ assert len(self.arg_types) == len(params["def_args"]), (
204
+ self.arg_types,
205
+ params["def_args"],
206
+ )
207
+ arg_type_loookup = dict(zip(params["def_args"], self.arg_types))
208
+ # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants
209
+ call_args = [
210
+ name for name in params["call_args"] if name not in triton_meta["constants"]
211
+ ]
212
+ arg_types = [arg_type_loookup[name] for name in call_args]
213
+ arg_signatures = [triton_meta["signature"][name] for name in call_args]
214
+ call_args_str = wrapper.generate_args_decl(
215
+ prefix,
216
+ call_args,
217
+ arg_types,
218
+ arg_signatures,
219
+ workspace_size=params.get("global_scratch") or 0,
220
+ )
221
+ prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
222
+ launch_kernel_args = [
223
+ kernel_var_name,
224
+ "grid_0",
225
+ "grid_1",
226
+ "grid_2",
227
+ str(params["num_warps"]),
228
+ str(params["shared_mem"]),
229
+ "kernel_args_",
230
+ "stream_",
231
+ ]
232
+ prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});")
233
+
234
+
235
+ class CppWrapperGpu(CppWrapperCpu):
236
+ """
237
+ Generates cpp wrapper for running on GPU and calls CUDA kernels
238
+ """
239
+
240
+ def __init__(self) -> None:
241
+ self.device = get_gpu_type()
242
+ self.device_codegen = get_device_op_overrides(self.device)
243
+ super().__init__()
244
+ self.grid_id = count()
245
+ self._kernel_name_to_body: dict[str, str] = {}
246
+ self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {}
247
+ self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT"
248
+
249
+ @staticmethod
250
+ def create(
251
+ is_subgraph: bool,
252
+ subgraph_name: Optional[str],
253
+ parent_wrapper: Optional[PythonWrapperCodegen],
254
+ partition_signatures: Optional[GraphPartitionSignature] = None,
255
+ ):
256
+ # TODO - support subgraph codegen by lifting functions. Check the
257
+ # comment at CppWrapperCpu `codegen_subgraph` function.
258
+ return CppWrapperGpu()
259
+
260
+ def write_header(self):
261
+ if V.graph.is_const_graph:
262
+ # We do not write header for constant graph, it will be written by main module.
263
+ return
264
+
265
+ super().write_header()
266
+ self.header.splice(
267
+ maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
268
+ )
269
+
270
+ @cache_on_self
271
+ def write_tma_descriptor_helpers_once(self):
272
+ self.header.splice(self.device_codegen.tma_descriptor_helpers())
273
+
274
+ def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
275
+ name = f"stream{device_idx}"
276
+ self.writeline(
277
+ maybe_hipify_code_wrapper(
278
+ f"{self.device_codegen.cpp_stream_type()} {name};"
279
+ )
280
+ )
281
+ self.writeline(
282
+ f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));"
283
+ )
284
+ return name
285
+
286
+ def get_autotuning_input_name(self, idx):
287
+ return f"{self.autotune_input_prefix}_{idx}"
288
+
289
+ def codegen_inputs(self):
290
+ # See Note: [Input Alignment handling in Inductor]
291
+ #
292
+ # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to
293
+ # copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp.
294
+
295
+ if config.is_fbcode():
296
+ # TODO: This is added because FC. Remove this once the newly added shim symbols,
297
+ # e.g. aoti_torch_clone_preserve_strides, have landed
298
+ return super().codegen_inputs()
299
+
300
+ if V.graph.aot_mode and V.graph.inputs_to_check:
301
+ for idx in V.graph.inputs_to_check:
302
+ input_name = V.graph.graph_input_names[idx]
303
+ assert input_name in V.graph.graph_inputs, (
304
+ f"{input_name} not found in graph inputs"
305
+ )
306
+ value = V.graph.graph_inputs[input_name]
307
+ assert isinstance(value, TensorBox), (
308
+ f"{input_name} is expected to be tensor but found as {type(value)}"
309
+ )
310
+ warn_msg = (
311
+ f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
312
+ "but it is not aligned at run time. Copying to an aligned tensor "
313
+ "to guarantee correctness, but expect a performance hit."
314
+ )
315
+ self.prefix.splice(
316
+ f"""
317
+ if ((long({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{
318
+ AOTI_TORCH_WARN("{warn_msg}");
319
+ AtenTensorHandle {input_name}_aligned;
320
+ aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned);
321
+ {input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned));
322
+ }}
323
+ """
324
+ )
325
+
326
+ super().codegen_inputs()
327
+
328
+ def _define_kernel_helper(
329
+ self,
330
+ kernel_name: str,
331
+ kernel_body: str,
332
+ metadata: Optional[str] = None,
333
+ gpu: bool = True,
334
+ cpp_definition: Optional[str] = None,
335
+ ):
336
+ if gpu:
337
+ self._kernel_name_to_body[kernel_name] = kernel_body
338
+ if config.triton.autotune_at_compile_time:
339
+ # Call PythonWrapperCodegen to create the autotune code block
340
+ PythonWrapperCodegen._define_kernel_helper(
341
+ self, kernel_name, kernel_body, metadata, gpu, cpp_definition
342
+ )
343
+ else:
344
+ return CppWrapperCpu._define_kernel_helper(
345
+ self, kernel_name, kernel_body, metadata, gpu, cpp_definition
346
+ )
347
+
348
+ def generate(self, is_inference):
349
+ with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
350
+ return super().generate(is_inference)
351
+
352
def finalize_prefix(self):
    """Define the triton kernel callers now that autotuning is finished.

    Final prefix layout: parent-class prefix content first, then the
    generated triton caller definitions, then whatever was already in the
    prefix before this call.
    """
    original_prefix = self.prefix  # new content must precede this

    # Generating triton kernel callers can mutate the prefix (cached dtype
    # declarations), so render them into a scratch buffer before running
    # super().finalize_prefix(), then splice them after its output.
    self.prefix = IndentedBuffer()
    for wrapper in self._triton_call_wrappers.values():
        self.prefix.writeline("\n")
        wrapper.generate(self)
    caller_definitions = self.prefix

    self.prefix = IndentedBuffer()
    super().finalize_prefix()

    self.prefix.splice(caller_definitions)

    self.prefix.writeline("\n")
    self.prefix.splice(original_prefix)
372
+
373
def generate_tma_descriptor(self, desc):
    """Emit a host-side TMA descriptor for `desc`.

    Dispatches between the experimental and the stable triton TMA APIs.
    """
    self.write_tma_descriptor_helpers_once()

    if isinstance(desc, TMADescriptorExperimental):
        self._generate_experimental_tma_descriptor(desc)
    else:
        assert isinstance(desc, TMADescriptorStable)
        self._generate_stable_tma_descriptor(desc)
381
+
382
def _generate_experimental_tma_descriptor(self, desc):
    """Emit an experimental-API CUtensorMap and initialize it via the
    init{1,2}DTMADescriptor helper."""
    # Generate the data pointer for the source tensor. These args are
    # passed to initNDTMADescriptor, which is NOT a triton kernel.
    source = self.generate_args_decl(
        code=self,
        call_args=[self.val_to_arg_str(desc.tensor)],
        arg_types=[desc.tensor.get_dtype()],
        arg_signatures=[None],
        is_triton_kernel=False,
    )

    desc_name = desc.name
    self.writeline(f"alignas(64) CUtensorMap {desc_name};")

    # `source` has the form `&var_x`, where `var_x` is the data pointer
    # (CUdeviceptr); dereference it and cast to `void*` so it can be
    # handed to `init{1,2}DTMADescriptor`.
    data_ptr = f"reinterpret_cast<void*>(*({source}))"
    dims = ", ".join(self.val_to_arg_str(d) for d in desc.dims)
    block_dims = ", ".join(self.val_to_arg_str(d) for d in desc.block_dims)
    element_size = self.val_to_arg_str(desc.element_size)
    helper = f"init{desc.rank}DTMADescriptor"
    helper_args = f"&{desc_name}, {data_ptr}, {dims}, {block_dims}, {element_size}"
    self.writeline(f"{helper}({helper_args});")
407
+
408
def _generate_stable_tma_descriptor(self, desc):
    """Emit a StableTMADescriptor struct for `desc` and initialize it.

    See [Note: AOTI TMA Stable handling] for why a struct (rather than a
    bare CUtensorMap) is generated.
    """
    # Data pointer for the source tensor; initTMADescriptor is NOT a
    # triton kernel.
    source = self.generate_args_decl(
        code=self,
        call_args=[self.val_to_arg_str(desc.tensor)],
        arg_types=[desc.tensor.get_dtype()],
        arg_signatures=[None],
        is_triton_kernel=False,
    )

    desc_name = desc.name
    # Pack the relevant information into a StableTMADescriptor struct.
    self.writeline(f"alignas(64) StableTMADescriptor {desc_name};")

    def fill_array(name, values):
        # One assignment per element; `values` may hold sympy expressions.
        for i, val in enumerate(values):
            self.writeline(f"{name}[{i}] = {val};")

    ptr = f"reinterpret_cast<void*>(*({source}))"
    rank = len(desc.tensor.get_size())

    fill_array(f"{desc_name}.block_shape", desc.block_shape)
    fill_array(f"{desc_name}.global_shape", desc.tensor.get_size())
    fill_array(f"{desc_name}.strides", desc.tensor.get_stride())

    element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize)
    helper_args = ", ".join(
        str(x)
        for x in (
            f"&{desc_name}.m",
            ptr,
            element_size,
            rank,
            f"{desc_name}.block_shape",
            f"{desc_name}.global_shape",
            f"{desc_name}.strides",
        )
    )
    self.writeline(f"initTMADescriptor({helper_args});")
449
+
450
def generate_args_decl(
    self,
    code: Union[IndentedBuffer, Self],
    call_args,
    arg_types,
    arg_signatures,
    is_triton_kernel=True,
    workspace_size=0,
):
    """
    Generates any declarations of args to pass into a kernel call, and then
    returns the arg names.

    In more detail:
    * declarations: this function has a side effect of writing lines like
      `auto var_0 = ...;` into `code`
    * returns: a string with the list of args, e.g. "var_0, var_1"

    call_args: list of call arguments
    arg_types: list of argument types
    arg_signatures: list with signatures of all the args
    is_triton_kernel: whether these are passed into a triton kernel or not.
        In particular, calls to triton kernels will have an additional
        global scratch space arg injected at the front of the arg list.
    """
    new_args: list[str] = []

    # Triton signature string -> C++ type. Add more cases as needed.
    signature2dtype = {
        "i32": "int32_t",
        "i64": "int64_t",
        "fp32": "float",
    }

    def signature_is_tma_desc(sig):
        # Both the experimental ("nvTmaDesc") and stable ("tensordesc<...>")
        # host-side TMA descriptor signatures count.
        if not sig:
            return False
        return sig == "nvTmaDesc" or sig.startswith("tensordesc<")

    def process_tma_stable_arg(arg, arg_type, arg_signature, var_name):
        # [Note: AOTI TMA Stable handling]
        # For most args, a single arg passed to the python triton interface
        # maps to a single arg in the cubin interface. However, for
        # host-side TMA descriptors, a single python arg turns into
        # 1 + 2 * N args in the cubin interface (where N is the rank).
        #
        # At TMA codegen time (for aoti) we generate a struct
        # (StableTMADescriptor) with the necessary information; here, at
        # call time, we unpack the struct members.
        code.writeline(f"auto {var_name} = {cexpr(arg)};")

        # Recover the rank from the signature string; format documented at
        # https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111  # noqa: B950
        match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature)
        assert match is not None
        ndim = match.group(2).count(",") + 1

        unpacked = [f"&{var_name}.m"]
        unpacked.extend(f"&{var_name}.block_shape[{i}]" for i in range(ndim))
        unpacked.extend(f"&{var_name}.strides[{i}]" for i in range(ndim))
        return unpacked

    def process_args(arg, arg_type, arg_signature=None):
        var_name = f"var_{next(self.arg_var_id)}"
        # The first two branches skip TMA descriptors, as host-side TMA
        # descriptors need to be passed to the compiled kernel by value.
        if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc(
            arg_signature
        ):
            # 0d CPU tensor wrapping an unspec scalar: extract via .item().
            self.codegen_tensor_item(
                arg_type.dtype,
                arg,
                var_name,
                indented_buffer=code,
            )
            new_args.append(f"&{var_name}")
        elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc(
            arg_signature
        ):
            device_ptr_type = self.device_codegen.cpp_device_ptr()
            code.writeline(
                maybe_hipify_code_wrapper(
                    f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
                )
            )
            new_args.append(f"&{var_name}")
        elif arg_type in (sympy.Integer, int):
            code.writeline(f"int {var_name} = {cexpr(arg)};")
            new_args.append(f"&{var_name}")
        elif arg_type in (sympy.Float, float):
            code.writeline(f"float {var_name} = {cexpr(arg)};")
            new_args.append(f"&{var_name}")
        elif (
            isinstance(arg_type, type(SymbolicCallArg))
            and arg_signature is not None
            and arg_signature in signature2dtype.keys()
        ):
            # For symbolic call arguments, use the arg signatures from
            # triton meta to cast explicitly: `auto` can infer an
            # unexpected type against the kernel input signature.
            code.writeline(
                f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
            )
            new_args.append(f"&{var_name}")
        elif arg_signature and arg_signature.startswith("tensordesc<"):
            new_args.extend(
                process_tma_stable_arg(arg, arg_type, arg_signature, var_name)
            )
        else:
            code.writeline(f"auto {var_name} = {cexpr(arg)};")
            new_args.append(f"&{var_name}")

    for arg, arg_type, arg_signature in zip_longest(
        call_args, arg_types, arg_signatures
    ):
        process_args(arg, arg_type, arg_signature)

    if (
        is_triton_kernel
        and (
            global_scratch := self.device_codegen.cpp_global_scratch(
                next(self.arg_var_id),
                workspace=TritonScratchWorkspace(
                    size=workspace_size,
                    generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
                ),
            )
        )
        is not None
    ):
        global_scratch_def, global_scratch_var = global_scratch
        code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
        new_args.append(f"&{global_scratch_var}")

    return ", ".join(new_args)
593
+
594
def _generate_kernel_call_helper(
    self,
    kernel_name: str,
    call_args,
    *,
    device=None,
    triton=True,
    arg_types=None,
    raw_keys=None,
    raw_args=None,
    triton_meta=None,
    graph_name="",
    original_fxnode_name=None,
):
    """
    Override the default value of argument 'gpu' to True here.
    generate_kernel_call can still be called with gpu=False because of
    a mix of cpu kernels and gpu kernels.
    """
    device = device or V.graph.get_current_device_or_throw()
    if device.type == "cpu":
        # Even in CppWrapperGpu, we may see cpp kernels.
        return CppWrapperCpu._generate_kernel_call_helper(
            self,
            kernel_name,
            call_args,
            device=device,
            triton=triton,
            arg_types=arg_types,
            raw_keys=raw_keys,
            raw_args=raw_args,
            triton_meta=triton_meta,
        )

    if (
        triton
        and config.triton.autotune_at_compile_time
        and kernel_name not in self.kernel_autotune_names
    ):
        # Call PythonWrapperCodegen to create the autotune code block.
        PythonWrapperCodegen._generate_kernel_call_helper(
            self,
            kernel_name,
            call_args,
            device=device,
            triton=triton,
            arg_types=arg_types,
            raw_keys=raw_keys,
            raw_args=raw_args,
            triton_meta=triton_meta,
            original_fxnode_name=original_fxnode_name,
        )

    stream = (
        "stream"
        if V.graph.aot_mode
        else self.write_get_raw_stream(device.index, graph_name)
    )

    if not triton:
        # Non-triton GPU kernel: invoke directly, casting pointer args to
        # the declared C++ type via .data_ptr().
        casted_args = []
        for arg_type, arg in zip(arg_types, call_args):
            value = arg
            if arg_type.endswith("*") and arg != "nullptr":
                value = f"{arg}.data_ptr()"
            casted_args.append(f"({arg_type}){cexpr(value)}")
        self.writeline(f"kernels.{kernel_name}({', '.join(casted_args)}, {stream});")
        return

    call_args, arg_types = self.prepare_triton_wrapper_args(call_args, arg_types)
    wrapper_name = f"call_{kernel_name}"
    if wrapper_name not in self._triton_call_wrappers:
        # Defer generating the caller until autotuning has finished.
        self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper(
            wrapper_name,
            kernel_name,
            self._kernel_name_to_body,
            arg_types,
        )
    device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index)
    call_args.append(device_idx)
    call_args.append(stream)
    if V.graph.aot_mode:
        call_args.append("kernels")
        call_args.append("this->cubin_dir_")
    debug_printer_manager = V.graph.wrapper_code.debug_printer
    debug_printer_manager.set_printer_args(
        call_args[: len(arg_types)], kernel_name, arg_types, None
    )
    with debug_printer_manager:
        self.writeline(f"{wrapper_name}({', '.join(call_args)});")
686
+
687
+ @staticmethod
688
+ def prepare_triton_wrapper_args(
689
+ call_args: list[Any], arg_types: list[Any]
690
+ ) -> tuple[list[Any], list[Any]]:
691
+ assert len(call_args) == len(arg_types), (call_args, arg_types)
692
+ new_args = []
693
+ new_args_types = []
694
+ for arg, arg_type in zip(call_args, arg_types):
695
+ if isinstance(arg, str):
696
+ if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg):
697
+ # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
698
+ arg_type = UnwrapUnspecArg(dtype=arg_type)
699
+ new_args.append(arg)
700
+ elif isinstance(arg, bool):
701
+ new_args.append(str(arg).lower())
702
+ elif isinstance(arg, (int, float, SymbolicCallArg)):
703
+ new_args.append(str(arg))
704
+ else:
705
+ new_args.append(cexpr(V.graph.sizevars.simplify(arg)))
706
+ new_args_types.append(arg_type)
707
+ return new_args, new_args_types
708
+
709
def make_zero_buffer(self, name):
    """Return the C++ statement that zero-fills buffer `name` via the AOTI shim."""
    zero_call = f"aoti_torch_zero_({name}.get())"
    return f"AOTI_TORCH_ERROR_CODE_CHECK({zero_call});"
711
+
712
+
713
@dataclasses.dataclass
class UnwrapUnspecArg:
    """Marker that we need to call .item() on the tensor"""

    # Scalar dtype produced by the .item() call.
    dtype: torch_dtype
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional
2
+
3
+ import sympy
4
+
5
+ import torch
6
+
7
+ from ..ir import GraphPartitionSignature
8
+ from ..virtualized import V
9
+ from .cpp_wrapper_gpu import CppWrapperGpu
10
+ from .wrapper import PythonWrapperCodegen
11
+
12
+
13
class CppWrapperMps(CppWrapperGpu):
    """C++ wrapper codegen for Metal (MPS) kernels."""

    @staticmethod
    def create(
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[GraphPartitionSignature] = None,
    ) -> "CppWrapperMps":
        return CppWrapperMps()

    def _generate_kernel_call_helper(
        self,
        kernel_name: str,
        call_args: list[str],
        arg_types: Optional[list[type]] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """
        Generates MPS kernel call code. It should look something like:
        ```
        auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
        auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
        mps_lib_0_func->runCommandBlock([&] {
            mps_lib_0_func->startEncoding();
            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0);
            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1);
            ...
            mps_lib_0_func->dispatch(9);
        });
        ```
        """
        assert arg_types is not None

        # The last two call args are the dispatch sizes; everything before
        # them is a kernel argument bound by index.
        kernel_body = []
        for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
            if isinstance(arg_type, torch.dtype):
                kernel_body.append(
                    f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n"
                )
            elif arg_type in (int, sympy.core.symbol.Symbol):
                kernel_body.append(
                    f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n"
                )
            else:
                raise NotImplementedError(
                    f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
                )

        threads, group_size = call_args[-2], call_args[-1]
        if threads is None:
            raise NotImplementedError("No threads or group_size provided")
        elif group_size is None:
            kernel_body.append(f"{kernel_name}->dispatch({threads});\n")
        else:
            kernel_body.append(f"{kernel_name}->dispatch({threads}, {group_size});\n")

        # debug printer related logic for cpp kernel type.
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args[:-2],
            kernel_name,
            None,
            None,
            "cpp",
        )
        with debug_printer_manager:
            self.writeline(self.wrap_kernel_call(kernel_name, kernel_body))

    def wrap_kernel_call(self, name: str, call_args: list[str]) -> str:
        # `name` has the form "<lib>_func"; recover the library variable name.
        lib_name = name[: -len("_func")]
        calling_args = " ".join(call_args)
        return f"""
    auto {name} = {lib_name}.getKernelFunction("generated_kernel");
    auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
    {name}->runCommandBlock([&] {{
        {name}->startEncoding();
        {calling_args}
    }});
    """

    @staticmethod
    def get_device_include_path(device: str) -> str:
        assert V.graph.aot_mode
        return (
            "#include <torch/csrc/inductor/aoti_include/mps.h>\n"
            "#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>"
        )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from textwrap import dedent
4
+
5
+ from .common import DeviceOpOverrides, register_device_op_overrides
6
+
7
+
8
class CpuDeviceOpOverrides(DeviceOpOverrides):
    """Device-op code snippets for the CPU backend.

    CPU needs no real stream or device management, so most hooks emit
    no-op code.
    """

    def import_get_raw_stream_as(self, name: str) -> str:
        # The "raw stream" on CPU is a dummy constant 0.
        return dedent(
            """
            def get_raw_stream(_):
                return 0
            """
        )

    def set_device(self, device_idx: int) -> str:
        return "pass"

    def synchronize(self) -> str:
        return "pass"

    def device_guard(self, device_idx: int) -> str:
        return "pass"


register_device_op_overrides("cpu", CpuDeviceOpOverrides())
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py ADDED
File without changes
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import hashlib
3
+ import logging
4
+ from collections.abc import Sequence
5
+ from typing import cast
6
+
7
+ from torch._inductor.codegen.cuda.cutlass_python_evt import (
8
+ CutlassEVTCodegen,
9
+ MockCutlassHandler,
10
+ )
11
+ from torch._inductor.utils import Placeholder
12
+ from torch.utils._ordered_set import OrderedSet
13
+
14
+ from ...._dynamo.utils import counters
15
+ from ... import config
16
+ from ...codecache import code_hash, get_path
17
+ from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
18
+ from ...scheduler import (
19
+ BaseSchedulerNode,
20
+ BaseScheduling,
21
+ FusedSchedulerNode,
22
+ SchedulerNode,
23
+ WhyNoFuse,
24
+ )
25
+ from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
26
+ from ...virtualized import V
27
+ from ..common import BackendFeature, IndentedBuffer
28
+
29
+
30
+ log = logging.getLogger(__name__)
31
+
32
+
33
class WhyNoFuseNames(WhyNoFuse):
    """WhyNoFuse variant constructed directly from node names (no nodes)."""

    def __init__(self, name1: str, name2: str) -> None:
        self.name1 = name1
        self.name2 = name2
37
+
38
+
39
class CUDACPPScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for CUDA C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.

    It handles fusion decisions and CUDA C++ specific template code generation.
    """

    @classmethod
    def get_backend_features(cls, device) -> OrderedSet[BackendFeature]:
        # No optional backend features are advertised for CUDA C++ templates.
        return OrderedSet()

    def group_fn(self, sizes):
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    @staticmethod
    def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
        # A template node is a SchedulerNode wrapping a CUDATemplateBuffer.
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CUDATemplateBuffer
        )

    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        # Case 1: a bare template node gaining its first epilogue.
        if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
            assert node1.node, "node1.node should not be None"
            return self._can_fuse_epilogue_impl(
                cast(CUDATemplateBuffer, node1.node),
                [],
                node2,  # type: ignore[arg-type]
            )
        # Case 2: an already-fused template gaining another epilogue.
        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
            node2, BaseSchedulerNode
        ):
            assert node1.node, "node1.node should not be None"
            assert node2.node, "node2.node should not be None"
            fnode1 = cast(FusedSchedulerNode, node1)
            return self._can_fuse_epilogue_impl(
                fnode1.get_template_node(),  # type: ignore[arg-type]
                self._unwrap_epilogue_nodes(fnode1),
                node2,  # type: ignore[arg-type]
            )

        return False

    def define_kernel(self, src_code: str, node_schedule) -> str:
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            return wrapper.src_to_kernel[src_code]

        fused_name = (
            get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
            if config.triton.descriptive_names
            else ""
        )

        # Hash the original src_code (before the name placeholder is filled).
        kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8]
        if fused_name == "fused":
            # no EVT kernel, use the original kernel name
            kernel_name = f"cutlass_{kernel_hash}"
        else:
            kernel_name = f"cutlass_{fused_name}_{kernel_hash}"
        wrapper.src_to_kernel[src_code] = kernel_name
        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name)

        _, _, kernel_path = get_path(code_hash(src_code), "py")

        compile_wrapper = IndentedBuffer()
        compile_wrapper.writeline("async_compile.cuda(r'''")
        compile_wrapper.splice(src_code, strip=True)
        compile_wrapper.writeline(f"''', 'so', aot_compile={str(V.graph.aot_mode)})")

        metadata_comment = f"# kernel path: {kernel_path}"
        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        metadata_comment += "\n" + origins + "\n" + detailed_origins
        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment)
        return kernel_name

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
        prologue_nodes: Sequence[BaseSchedulerNode],
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cuda_cpp_template(template_node), (
            "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
        )
        template_node = cast(SchedulerNode, template_node)
        _, (_numel, rnumel) = template_node.group
        assert rnumel == 1
        ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
            "Epilogue nodes must all be instances of ir.ComputedBuffer"
        )
        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)

        with kernel:
            for node in [template_node, *epilogue_nodes]:
                node.mark_run()

            # Typically a codegen pass runs after mark_run. Here the C++ is
            # already generated, but the kernel must still see the epilogue
            # loads/stores so memory planning can optimize allocations.
            ctb.emulate_store_fn()
            for node in epilogue_ir_nodes:
                with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())):
                    assert isinstance(
                        node, ComputedBuffer
                    )  # Not sure why we need to do this again
                    node.get_store_function()(CutlassEVTCodegen.get_index_vars(node))

            with V.set_kernel_handler(kernel):
                src_code = render()
                node_schedule = [template_node, *epilogue_nodes]
                kernel_name = self.define_kernel(src_code, node_schedule)

        # debug printing values of intermediate tensors
        _, call_args, arg_signatures, _ = kernel.args.python_argdefs()
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args, kernel_name, arg_signatures, kernel
        )
        with debug_printer_manager:
            kernel.call_kernel(kernel_name, ctb)

        V.graph.removed_buffers |= kernel.removed_buffers
        self.free_buffers_in_scheduler()

    @staticmethod
    def _unwrap_epilogue_nodes(
        fused_node: FusedSchedulerNode,
    ) -> list[BaseSchedulerNode]:
        # Everything in the fused node except the template itself.
        nodes = fused_node.get_nodes()
        template_node = fused_node.get_template_node()
        assert all(n.node is not None for n in nodes), (
            "All epilogue nodes should have an IRNode"
        )
        return cast(
            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
        )

    def _can_fuse_epilogue_impl(
        self,
        cuda_template_buffer: CUDATemplateBuffer,
        existing_epilogue_nodes: list[BaseSchedulerNode],
        node_to_fuse: BaseSchedulerNode,
    ) -> bool:
        """
        Check if the given node can be fused with the epilogue. At the moment, Kernels
        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

        Args:
            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
            existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
            node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
        Returns:
        - bool: True if the given node can be fused with the epilogue, False otherwise.

        """
        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())

        scheduler_nodes_to_fuse = node_to_fuse.get_nodes()

        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)

        # Checks on constituent nodes
        for s_node in scheduler_nodes_to_fuse:
            node = s_node.node

            if not isinstance(node, ComputedBuffer):
                why(f"{node} is not a ComputedBuffer")
                return False
            elif not isinstance(node.data, Pointwise):
                why(f"{node} is not a Pointwise op")
                return False
            elif not node.get_computed_buffer_name():  # type: ignore[attr-defined]
                why(f"{node} does not have a computed buffer name")
                return False

            name = node.get_computed_buffer_name()  # type: ignore[attr-defined]
            # dtype can differ, and strides can differ as long as they are broadcastable
            if node.get_size() != cuda_template_buffer.get_size():
                why(
                    f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
size: {cuda_template_buffer.get_size()}"
                )
                return False

        assert len(
            existing_epilogue_nodes
        ) or cuda_template_buffer.get_name() in OrderedSet(
            [rd.name for rd in node_to_fuse.read_writes.reads]
        ), "First epilogue node must read from cuda template buffer"

        if node_to_fuse.has_aliasing_or_mutation():
            why(f"{node_to_fuse.get_name()} has aliasing or mutation")
            return False
        elif node_to_fuse.is_reduction():
            why(
                f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT"
            )
            return False
        elif (
            not config.cuda.cutlass_epilogue_fusion_enabled
            or not config.epilogue_fusion
        ):
            why("cutlass epilogue fusion is not enabled")
            return False
        elif not cuda_template_buffer.supports_epilogue_fusion:
            why("epilogue fusion is only supported for TMA-enabled gemm ops")
            return False

        # Final check: trial-run the EVT python codegen; any NotImplementedError
        # means the epilogue uses an op/dtype the EVT path cannot express.
        try:
            from torch._inductor.codegen.cuda.cutlass_python_evt import (
                CutlassEVTCodegen,
            )

            CutlassEVTCodegen.ir_to_evt_python_code(
                cuda_template_buffer.get_name(),
                existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
                OrderedSet(),
            )

        except NotImplementedError as e:
            not_implemented_op = str(e)
            if not_implemented_op.startswith("_op_"):
                not_implemented_op = not_implemented_op[4:]
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
            else:  # Likely due to unsupported dtype.
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
Reason: {not_implemented_op}"  # noqa: G004, B950
                )
                return False

        return True
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ import shutil
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch._inductor.utils import clear_on_fresh_cache
8
+
9
+ from ... import config
10
+
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_cuda_arch() -> Optional[str]:
    """Return the CUDA arch as a string like "90", or None on failure.

    Uses config.cuda.arch when set; otherwise falls back to the compute
    capability of the first visible device.
    """
    try:
        configured = config.cuda.arch
        if configured is not None:
            return str(configured)
        major, minor = torch.cuda.get_device_capability(0)
        return str(major * 10 + minor)
    except Exception as e:
        log.error("Error getting cuda arch: %s", e)
        return None
28
+
29
+
30
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_cuda_version() -> Optional[str]:
    """Return the CUDA toolkit version string, or None on failure.

    Uses config.cuda.version when set; otherwise torch.version.cuda.
    """
    try:
        version = config.cuda.version
        return torch.version.cuda if version is None else version
    except Exception as e:
        log.error("Error getting cuda version: %s", e)
        return None
41
+
42
+
43
@functools.cache
def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
    """Return True iff `nvcc_path` names an executable findable on PATH."""
    if nvcc_path is None:
        return False
    return shutil.which(nvcc_path) is not None
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ from collections import defaultdict
6
+ from dataclasses import dataclass
7
+ from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
8
+
9
+ from sympy import Expr, symbols
10
+
11
+ import torch._inductor.config as config
12
+ from torch import dtype as torch_dtype
13
+ from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
14
+ from torch._inductor.scheduler import BaseSchedulerNode
15
+ from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
16
+ from torch.utils._sympy.value_ranges import ValueRanges
17
+
18
+ from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from .cuda_template import ArgInfo
23
+
24
+ from ...autotune_process import CUDABenchmarkRequest
25
+ from ...ir import (
26
+ Buffer,
27
+ ChoiceCaller,
28
+ CUDATemplateBuffer,
29
+ IRNode,
30
+ Layout,
31
+ PrimitiveInfoType,
32
+ TensorBox,
33
+ )
34
+ from ...utils import sympy_product
35
+ from ...virtualized import V
36
+ from ..common import (
37
+ CSEVariable,
38
+ IndentedBuffer,
39
+ Kernel,
40
+ OpOverrides,
41
+ WorkspaceArg,
42
+ WorkspaceZeroMode,
43
+ )
44
+ from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
49
+
50
+ log = logging.getLogger(__name__)
51
+
52
+ cexpr = CppPrinter().doprint
53
+
54
+
55
+ def _normalize_idx(index: int, total_length: int) -> int:
56
+ return index if index >= 0 else index + total_length
57
+
58
+
59
# Symbols a GEMM layout may expose: the problem sizes (M, N, K), the batch
# dimension (B), and the leading dimensions of the A/B/C/D operands.
ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
# Which attribute of an IRNode a layout symbol may be bound to.
ValidLayoutAttrs = Literal["size", "stride"]
61
+
62
+
63
@dataclass(frozen=True)
class LayoutArg:
    """
    Binds a layout symbol (e.g. "M" or "lda") to one dimension of an
    IRNode's size or stride.
    """

    node: IRNode
    symbol: ValidLayoutSymbols
    attr: ValidLayoutAttrs
    dim: int

    def matches(self, node, attr, dim) -> bool:
        """Return True iff this arg refers to the given (node, attr, dim)."""
        return all(
            (
                self.node == node,
                self.attr == attr,
                self.dim == dim,
            )
        )
72
+
73
+
74
class CUDAKernel(Kernel):
    """
    Baseclass for CUDA / Cutlass based Kernels
    """

    overrides = OpOverrides  # type: ignore[assignment]

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Layout symbol (e.g. "M", "lda") -> list of LayoutArgs registered
        # under that symbol.
        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
        # Extra symbolic sizes (beyond the fixed GEMM layout args) that are
        # passed to the kernel as additional integer arguments.
        self.size_args: list[Union[Expr, int]] = []
        # Mapping from arg name to IRNode.
        self.named_nodes: dict[str, IRNode] = {}

    def find_symbol(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[str]:
        """Return the layout symbol bound to (node, attr, dim), or None."""
        arg = self.find_layout_arg(node, attr, dim)
        return arg.symbol if arg else None

    def find_layout_arg(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[LayoutArg]:
        """
        Return a LayoutArg registered for (node, attr, dim), or None.

        Several symbols may be registered against the same triple (e.g. "K"
        is registered on both X and W); the first match is returned.
        """
        matches = [
            arg
            for arg in itertools.chain.from_iterable(self.layout_args.values())
            if arg.matches(node, attr, dim)
        ]
        if len(matches) >= 1:
            # Verify all matches have the same node, attribute, and dimension
            # And if they come from the same node, whichever symbol we use is fine.
            # if in runtime the logic changes, this would trigger guard
            first_match = matches[0]
            if not all(
                match.node == first_match.node
                and match.attr == first_match.attr
                and match.dim == first_match.dim
                for match in matches
            ):
                raise AssertionError("All matching layout args should be identical")
            return first_match
        return None

    def add_layout_arg(
        self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ):
        """Register ``symbol`` as naming dimension ``dim`` of ``node``'s ``attr``."""
        arg = LayoutArg(node, symbol, attr, dim)
        self.layout_args[symbol].append(arg)

    def init_layout_args(self) -> None:
        """
        Register the canonical GEMM layout symbols (M, N, K, B and the
        leading dimensions) against the X / W / Y nodes and the optional
        Bias node. Expects ``self.named_nodes`` to already contain "X",
        "W" and "Y".
        """
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        x_mdim = _normalize_idx(-2, len(X.get_size()))
        x_kdim = _normalize_idx(-1, len(X.get_size()))
        w_kdim = _normalize_idx(-2, len(W.get_size()))
        w_ndim = _normalize_idx(-1, len(W.get_size()))
        y_mdim = _normalize_idx(-2, len(Y.get_size()))
        y_ndim = _normalize_idx(-1, len(Y.get_size()))
        self.add_layout_arg("M", X, "size", x_mdim)
        self.add_layout_arg("K", X, "size", x_kdim)
        self.add_layout_arg("K", W, "size", w_kdim)
        self.add_layout_arg("N", W, "size", w_ndim)
        self.add_layout_arg("M", Y, "size", y_mdim)
        self.add_layout_arg("N", Y, "size", y_ndim)
        # A rank > 2 input is treated as batched, with the batch in dim 0.
        if len(X.get_size()) > 2:
            self.add_layout_arg("B", X, "size", 0)

        lda_dim = self.find_ld_idx(X)
        ldb_dim = self.find_ld_idx(W)
        ldc_dim = self.find_ld_idx(Bias) if Bias else None
        ldd_dim = self.find_ld_idx(Y)
        self.add_layout_arg("lda", X, "stride", lda_dim)
        self.add_layout_arg("ldb", W, "stride", ldb_dim)
        if Bias is not None and ldc_dim is not None:
            self.add_layout_arg("ldc", Bias, "stride", ldc_dim)
        self.add_layout_arg("ldd", Y, "stride", ldd_dim)

    def get_layout_args(self) -> tuple[Union[Expr, int], ...]:
        """
        Return the concrete values (sizes / strides, possibly symbolic) for
        the layout symbols, in the fixed order
        (M, N, K, B, LDA, LDB, LDC, LDD). LDC is 0 when there is no Bias.
        """
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        mdim = _normalize_idx(-2, len(X.get_size()))
        ndim = _normalize_idx(-1, len(W.get_size()))
        kdim = _normalize_idx(-1, len(X.get_size()))

        def get_ld(node) -> Union[Expr, int]:
            # Leading dimension = stride of the non-contiguous matrix dim.
            dim = self.find_ld_idx(node)
            return node.get_stride()[dim]

        M = X.get_size()[mdim]
        N = W.get_size()[ndim]
        K = X.get_size()[kdim]
        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
        LDA = get_ld(X)
        LDB = get_ld(W)
        LDC = get_ld(Bias) if Bias else 0
        LDD = get_ld(Y)
        return (M, N, K, B, LDA, LDB, LDC, LDD)

    def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
        """All integer args passed to the kernel: layout args then extra sizes."""
        return [*self.get_layout_args(), *self.size_args]

    @staticmethod
    def find_ld_idx(node: IRNode) -> int:
        """
        Return the dimension index whose stride is the leading dimension of
        ``node``: dim -2 when the innermost dim is contiguous (row-major),
        otherwise dim -1 (then stride[-2] must be 1).
        """
        strides = node.get_stride()
        # Innermost dim contiguous -> leading dimension is stride of dim -2.
        # NOTE(review): for a 1D node this returns _normalize_idx(-2, 1) == -1,
        # i.e. the last (only) dim — confirm that is the intended behavior.
        if V.graph.sizevars.statically_known_equals(strides[-1], 1):
            return _normalize_idx(-2, len(strides))

        assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2]
        return _normalize_idx(-1, len(strides))
188
+
189
+
190
class CUDATemplateKernel(CUDAKernel):
    """
    Template kernels defined by CUDA / Cutlass in C++.
    """

    # Trailing C++ parameters appended to every generated kernel signature:
    # an out-param used for workspace-size queries, the workspace buffer
    # itself, and the CUDA stream to launch on.
    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"

    def __init__(
        self,
        kernel_name: str,
        runtime_arg_info: list["ArgInfo"],
        runtime_arg_values: list[Any],
    ) -> None:
        """
        Initializes a new instance of the CUDATemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
            runtime_arg_info: Name/type descriptions of extra runtime args
                that follow the size args in the kernel signature.
            runtime_arg_values: Concrete values for those runtime args, in
                the same order as ``runtime_arg_info``.
        """
        super().__init__()
        self.kernel_name = kernel_name
        self.runtime_arg_info = runtime_arg_info
        self.runtime_arg_values = runtime_arg_values

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.

        The emitted C++ throws a runtime_error when the argument pointer is
        null but the node's element count is non-zero. Returns "" for a None
        node or one without a registered arg name.
        """
        if node is None:
            return ""

        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
              if (!{name_str}) {{
                int64_t {name_str}_size = {size_str};
                if ({name_str}_size > 0) {{
                  throw std::runtime_error("input {name_str} is null but size is not 0!");
                }}
              }}
            }}
            """
        )
        return res.getvalue()

    def get_signature(self) -> str:
        # ``self.signature`` is produced as a side effect of ``def_kernel``.
        return self.signature

    def def_kernel(
        self,
        inputs: list[IRNode],
        outputs: list[IRNode],
        names_str: str = "",
        input_reorder: Optional[list[int]] = None,
    ) -> str:
        """
        Hook called from template code to generate function definition and
        needed args.

        Args:
            inputs: List of input IRNodes
            outputs: List of output IRNodes
            names_str: Comma separated list of input + output argument names.
            input_reorder: The actual order of input nodes.
                e.g. The template might have input argument defined as [X, W, Bias],
                and the actual input passed into this template could be [Bias, X, W].
                In this case, the `input_reorder` would be [2, 0, 1].
            additional_size_args: Additional size arguments for epilogue inputs
        """
        names = [x.strip() for x in names_str.strip().split(",")]
        if len(inputs) + len(outputs) != len(names):
            raise RuntimeError(
                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
            )

        if input_reorder is not None:
            assert len(inputs) == len(input_reorder)
        else:
            input_reorder = list(range(len(inputs)))

        # Register each input buffer under its template-facing name.
        for idx in input_reorder:
            name = names[idx]
            node = inputs[idx]
            if node is not None:
                self.named_nodes[name] = node
                self.args.input_buffers[node.get_name()] = name

        # Collect free symbols from the shapes/strides of outputs other than
        # the canonical X/W/Bias/Y operands; those symbols become extra int
        # kernel arguments.
        free_symbols: OrderedSet[Expr] = OrderedSet()
        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
            if node is not None:
                self.named_nodes[name] = node
                self.args.output_buffers[node.get_name()] = name

                if name not in (
                    "X",
                    "W",
                    "Bias",
                    "Y",
                ):  # we handle these symbolic shapes explicitly
                    for expr in itertools.chain(node.get_size(), node.get_stride()):
                        if isinstance(expr, Expr):
                            for s in expr.free_symbols:
                                free_symbols.add(s)  # type: ignore[arg-type]

        arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)

        self.init_layout_args()
        # Fixed GEMM layout symbols first, then the extra free symbols; all
        # are declared as plain `const int` parameters.
        size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
        size_vars.extend(str(s) for s in free_symbols)
        self.size_args.extend(free_symbols)
        size_args = [f"const int {s}" for s in size_vars]

        runtime_arg_decls = ",".join(
            [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
        )
        if runtime_arg_decls:
            runtime_arg_decls += ", "

        signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
        self.signature = signature
        return signature

    def call_kernel(
        self,
        name: str,
        node: "CUDATemplateBuffer",  # type: ignore[name-defined]
    ) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.
        used from within torch._inductor.wrapper.PythonWrapperCodegen

        name: Name of kernel function.
        node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
        as well as all required inputs and outputs.
        """
        wrapper = V.graph.wrapper_code

        arg_types: list[Any]
        if V.graph.cpp_wrapper:
            # Make sure we initialize these kernels since they're exported as
            # C-style symbol names.
            assert isinstance(wrapper, CppWrapperCpu)
            wrapper.initialized_kernels[name] = self
            # We always originally initialize name with "KERNEL_NAME". So, we
            # we replace with the real kernel name passed as an arg to this function.
            self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name)
            _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
        else:
            _, call_args, _, arg_types = self.args.python_argdefs()

        # Append the integer size args and the template runtime args after the
        # buffer args, keeping call_args and arg_types in lockstep.
        dynamic_shape_args = self.get_dynamic_shape_args()
        call_args.extend(dynamic_shape_args)  # type: ignore[arg-type]
        for arg in self.runtime_arg_values:
            call_args.append(arg)
        arg_types.extend("int" for _ in dynamic_shape_args)
        for arg in self.runtime_arg_info:
            arg_types.append(arg.ty)
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
            elif isinstance(arg_types[i], torch_dtype):
                call_args[i] = (
                    call_args[i]
                    if V.graph.cpp_wrapper
                    else f"c_void_p({call_args[i]}.data_ptr())"
                )

        # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
        # workspace_size should have already been retrieved prior to this call.
        # workspace_size is here.
        call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
        if V.graph.cpp_wrapper:
            arg_types.append("size_t*")

        if node.get_workspace_size() > 0:
            # Allocate a scratch workspace buffer for the kernel and pass its
            # pointer; it is deallocated again after the kernel call below.
            ws = WorkspaceArg(
                count=node.get_workspace_size(),
                device=V.graph.get_current_device_or_throw(),
                zero_mode=WorkspaceZeroMode.UNINITIALIZED,
                outer_name=WorkspaceArg.unique_name(),
            )
            wrapper.generate_workspace_allocation(ws)
            workspace = str(ws.outer_name)
            call_args.append(
                workspace
                if V.graph.cpp_wrapper
                else f"c_void_p({workspace}.data_ptr())"
            )
        else:
            ws = None
            call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
        if V.graph.cpp_wrapper:
            arg_types.append("uint8_t*")

        wrapper.generate_kernel_call(
            name,
            call_args,
            triton=False,
            arg_types=arg_types,
        )
        if ws:
            wrapper.generate_workspace_deallocation(ws)

    def dtype(self, node: IRNode) -> Optional[str]:
        """
        Generates code which represents dtype of a given node.
        Returns "void" for a None node.
        """

        if node is None:
            return "void"
        return DTYPE_TO_CPP.get(node.get_layout().dtype)

    def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
        # Helper method, called into from CUTLASSGemmTemplate.
        # Maps the node's torch dtype to its CUTLASS C++ type name.
        if node is None:
            return default_dtype
        from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate

        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

    def max_valid_index(self, node: IRNode, default=-1):
        # Helper method, called into from CUTLASSGemmTemplate.
        # Largest flat offset addressable in the node's layout:
        # sum over dims of (size - 1) * stride.
        if node is None:
            return default
        max_valid_offset = 0
        for i in range(len(node.get_size())):
            max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
        return max_valid_offset

    def offset(self, node: IRNode) -> str:
        """
        Generates code which represents offset of a given node.
        Returns "0" for a None node.
        """

        if node is None:
            return "0"
        return str(node.get_layout().offset)  # type: ignore[union-attr]

    def ptr(self, node: IRNode) -> str:
        """
        Generates code which represents pointer of a given node.
        Emits "arg" or "arg + offset"; "nullptr" when the node is None or
        has no registered arg name.
        """

        if node is None:
            return "nullptr"
        arg_name = self.arg_name(node)
        if arg_name is None:
            return "nullptr"
        offset = self.offset(node)
        return arg_name if offset == "0" else f"{arg_name} + {offset}"

    def size(
        self,
        node: IRNode,
        start_index: int,
        end_index: Optional[int] = None,
        default_value: int = 0,
    ) -> str:
        """
        Hook called from template code to get the size of an arg.
        Generates code which represents size of a given node in [start_index, end_index).
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        start_index = _normalize_idx(start_index, len(node.get_size()))
        if end_index is None:
            end_index = start_index
        end_index = _normalize_idx(end_index, len(node.get_size()))
        # Prefer registered layout symbols (so dynamic sizes reference the
        # kernel's int args); fall back to the node's literal sizes.
        sizes = [
            self.find_symbol(node, "size", dim=i) or node.get_size()[i]
            for i in range(start_index, end_index + 1)
        ]
        if len(sizes) == 0:
            return str(default_value)

        sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
        # NOTE(review): despite the ``-> str`` annotation this returns a sympy
        # expression; callers (e.g. check_not_null) interpolate it into
        # strings, which stringifies it implicitly.
        val = sympy_product(sizes)
        return val

    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
        """
        Hook called from template code to get the stride of an arg.
        Generates code which represents stride of a given node at index.
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        index = _normalize_idx(index, len(node.get_size()))
        if index < 0:
            return str(default_value)

        stride = node.get_stride()[index]
        # Strides statically known to be 0 or 1 are emitted as literals;
        # otherwise prefer the registered stride symbol.
        if V.graph.sizevars.statically_known_leq(stride, 1):
            return str(stride)
        return self.find_symbol(node, "stride", dim=index) or str(stride)

    def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the batch stride of an arg.
        Returns 0 if batch dim is not present.

        This method assumes that batch stride is the largest stride.
        """

        if node is None:
            return str(default_value)

        # Fewer than 3 dims -> no batch dimension.
        if len(node.get_size()) < 3:
            return str(default_value)

        batch_stride = node.get_stride()[0]
        if V.graph.sizevars.statically_known_leq(batch_stride, 1):
            return str(batch_stride)

        # Batch stride expressed as rows * cols (size[1] * size[2]), using
        # registered size symbols where available.
        return "{}*{}".format(
            self.find_symbol(node, "size", dim=1) or node.get_size()[1],
            self.find_symbol(node, "size", dim=2) or node.get_size()[2],
        )

    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the row or column stride of an arg.
        This is required by some CUTLASS 2.X APIs.
        If the node is in row_major, it returns stride[-2].
        If the node is in column_major, it returns stride[-1].

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None or len(node.get_stride()) < 2:
            return str(default_value)

        stride0 = node.get_stride()[-1]
        stride1 = node.get_stride()[-2]
        if stride0 == 1:
            return cexpr(self.rename_indexing(stride1))
        elif stride1 == 1:
            return cexpr(self.rename_indexing(stride0))
        else:
            raise RuntimeError(
                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
            )

    def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable:
        """
        Mock load function for memory planning to optimize allocations properly.
        """
        return self.create_cse_var(name, bounds=ValueRanges.unknown())

    def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
        """
        Mock store function for memory planning to optimize allocations properly.
        """
        self.store_buffer_names.add(name)
561
+
562
+
563
class CUDATemplateCaller(ChoiceCaller):
    """
    CUDATemplateCaller

    This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
        template_buffer (CUDATemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_kernel_render: Callable[
            [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]],
            tuple[CUDATemplateKernel, functools.partial[str]],
        ],
        bmreq: CUDABenchmarkRequest,
        supports_epilogue_fusion: bool,
        template: "CUDATemplate",  # type: ignore[name-defined]
        info_kwargs: Optional[
            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
        ],  # type: ignore[type-arg]
        description: str,
    ) -> None:
        super().__init__(name, input_nodes, layout, description)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.supports_epilogue_fusion = supports_epilogue_fusion
        self.template = template
        # Extra op metadata (e.g. the CUTLASS "op" and "swizzle") used by
        # info_dict() / hash_key(); may be None.
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        """Compile the benchmark request's kernel ahead of benchmarking."""
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        """
        Benchmark this choice and return its runtime. Uses the
        profiler-based path when the corresponding config flag is set.
        """
        assert self.bmreq is not None
        if config.profile_bandwidth_with_do_bench_using_profiling:
            algo = self.bmreq.make_run_fn(*args, out=out)
            return do_bench_using_profiling(algo)
        return self.bmreq.benchmark(*args, out=out)

    def __str__(self) -> str:
        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        return f"cuda_template_kernels.{self.name}"

    def kernel_hash_key(self) -> str:
        """
        Return kernel hash key that does not depend on swizzle.
        """
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def hash_key(self) -> str:
        """
        Return the full hash key; unlike kernel_hash_key(), this one also
        includes the swizzle value.
        """
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
                str(self.info_dict().get("swizzle")),
            ]
        )

    def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        if self.info_kwargs is not None and "op" in self.info_kwargs:
            op: Any = self.info_kwargs["op"]
            return {
                "backend": "CUDA",
                "op_type": type(op).__name__,
                "op_conf_name": str(op.configuration_name()),
                "op_arch": str(op.arch),
                "tile_shape": str(op.tile_description.tile_shape),
                "epilogue_schedule": str(op.epilogue_schedule),
                "kernel_schedule": str(op.kernel_schedule),
                "element_accumulator": str(op.accumulator_type()),
                "op_name": str(op.procedural_name()),
                "instruction_shape": str(
                    op.tile_description.math_instruction.instruction_shape
                ),
                "swizzle": str(self.info_kwargs["swizzle"]),
            }
        else:
            # No op metadata available; emit a minimal record.
            return {"backend": "CUDA", "op_type": "unknown"}

    def output_node(self) -> TensorBox:
        """Materialize this choice as a CUDATemplateBuffer wrapped in a TensorBox."""
        # Refresh the workspace size before the buffer captures it.
        self.bmreq.update_workspace_size()
        return TensorBox.create(
            CUDATemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                workspace_size=self.bmreq.workspace_size,
                supports_epilogue_fusion=self.supports_epilogue_fusion,
                template=self.template,
            )
        )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import hashlib
4
+ import itertools
5
+ from dataclasses import dataclass
6
+ from typing import Any, Optional, TYPE_CHECKING
7
+ from typing_extensions import override
8
+ from unittest.mock import patch
9
+
10
+ import sympy
11
+
12
+ import torch
13
+ from torch._inductor.utils import Placeholder
14
+ from torch._logging import getArtifactLogger
15
+
16
+ from ...autotune_process import CUDABenchmarkRequest, TensorMeta
17
+ from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
18
+ from ...utils import IndentedBuffer, unique
19
+ from ...virtualized import V
20
+ from ..common import KernelTemplate
21
+ from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
22
+ from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ...scheduler import BaseSchedulerNode # noqa: TC004
27
+ else:
28
+ BaseSchedulerNode = Any
29
+
30
+ GemmOperation = Any
31
+
32
+ autotuning_log = getArtifactLogger(__name__, "autotuning")
33
+
34
+
35
@dataclass(frozen=True)
class ArgInfo:
    """
    Describes one runtime kernel argument, e.g.
    ArgInfo("swizzle", "const uint8_t").
    """

    # Argument name as it appears in the generated kernel signature.
    name: str
    # C++ type string for the argument declaration.
    ty: str
39
+
40
+
41
class CUDATemplate(KernelTemplate):
    # Monotonic counter shared by all CUDATemplate instances.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        input_reorder: Optional[list[int]] = None,
    ) -> None:
        """

        Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.

        Args:
            name (str): The name of the CUDATemplate object.
            input_nodes (List[IRNode]): A list of input IRNodes.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.

        """
        super().__init__(name)
        self.input_nodes = input_nodes
        self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
        self.input_reorder = input_reorder
        self.layout = layout

    @staticmethod
    def supports_epilogue_fusion(op: GemmOperation) -> bool:
        # Conservative default; subclasses may opt in per-op.
        return False

    def generate(  # type: ignore[override]
        self,
        description,
        **kwargs,
    ) -> CUDATemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
        may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.

        Args:
            kwargs: Additional keyword arguments.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.
        """
        kernel_name = str(Placeholder.KERNEL_NAME)
        # Render the kernel source once, with get_dtype faked to report the
        # (not-yet-registered) output buffer's dtype.
        with (
            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
            CUDATemplateKernel(
                kernel_name=kernel_name,
                runtime_arg_info=self.get_runtime_arg_info(),
                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
            ) as kernel,
        ):
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            autotuning_log.debug("Generated Code:\n%s", code)
            autotuning_log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE),
                kernel.args.python_argdefs(),
            )

        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        # Sanity-check that the rendered kernel's buffer args are exactly the
        # (reordered, de-duplicated) inputs followed by the output.
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.extend([self.output_node.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # NOTE(review): return value unused — presumably evaluated for the
        # side effect of materializing size hints/guards; confirm.
        V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
        size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
        extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))

        # Give the kernel a content-derived name so distinct renderings don't
        # collide, and rewrite the source to match.
        kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8]
        kernel_name = f"cutlass_{kernel_hash}"
        code = code.replace(self.name, kernel_name)

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        # kwargs has "op" argument in case of CUTLASSGemmTemplate
        op = kwargs["op"]
        if not op:
            supports_epilogue_fusion = False
        else:
            # epilogue fusion is only supported for TMA kernels
            supports_epilogue_fusion = self.supports_epilogue_fusion(op)

        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
        ) -> tuple[CUDATemplateKernel, functools.partial[str]]:
            # Deferred render used at lowering time; builds a fresh kernel
            # object so scheduling can re-render with epilogue nodes fused in.
            assert supports_epilogue_fusion or not epilogue_nodes, (
                "epilogue fusion is not supported for this kernel"
            )
            kernel = CUDATemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                runtime_arg_info=self.get_runtime_arg_info(),
                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_name,
            "cutlass_gemm",
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            supports_epilogue_fusion,
            self,
            kwargs,
            description,
        )

    def header(self) -> IndentedBuffer:
        """Common C++ standard-library includes shared by all CUDA templates."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Global C++ definitions (symbol-visibility macro) for the kernel source."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
            """
        )
        return res

    def render(self, **kwargs) -> str:
        # Subclasses must produce the full kernel source string.
        raise NotImplementedError

    def get_runtime_arg_info(self) -> list[ArgInfo]:
        # No extra runtime args by default; subclasses may override.
        return []

    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
        # Matches get_runtime_arg_info(); empty by default.
        return []
218
+
219
+
220
class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        """Extend the base header with the CUTLASS / CuTe includes."""
        res = super().header()
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Extend the base globals with the CUTLASS_CHECK macro and helpers."""
        res = super().globals()
        res.splice(
            """
                using namespace cute;
                #define CUTLASS_CHECK(status)                                                      \\
                {                                                                                  \\
                  cutlass::Status error = status;                                                  \\
                  if (error != cutlass::Status::kSuccess) {                                        \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
                    throw std::runtime_error(msg);                                                 \\
                  }                                                                                \\
                }

                // Used as pass-through functor in EVT just for type casting / rounding
                template <typename T>
                struct identity_op {
                  CUTLASS_HOST_DEVICE
                  T operator()(T val) const { return val; }
                };

            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """
        Render an integer for CuTe: compile-time Int<1> when the value is
        statically 1, otherwise the runtime expression, annotated with the
        variable name as a comment.
        """
        res = ""
        if int_str in ("1", "1L"):
            res = "cute::Int<1>{}"
        else:
            res = int_str

        return f"{res} /* {var_name} */"

    # torch dtype -> CUTLASS C++ element type name.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int32_t",
        torch.int16: "int16_t",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
        torch.float8_e4m3fn: "cutlass::float_e4m3_t",
    }

    # torch dtype -> C++ type used for CUTLASS sparse metadata tensors.
    _DTYPE_TO_CUTLASS_SPARSE_META = {
        torch.int32: "uint32_t",
        torch.int16: "uint16_t",
    }

    def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
        """Cast ``ptr`` to a pointer of the node's CUTLASS element type."""
        if node is None:
            return ptr
        else:
            return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"

    def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str:
        """Cast ``ptr`` to a pointer of the node's sparse-metadata type."""
        if node is None:
            return ptr
        else:
            return (
                f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
            )

    @override
    def get_runtime_arg_info(self) -> list[ArgInfo]:
        # CUTLASS kernels take a single extra runtime arg: the swizzle.
        return [ArgInfo("swizzle", "const uint8_t")]

    @override
    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
        """
        Helper method to retrieve runtime args from generate kwargs
        """
        return [kwargs[arg.name] for arg in self.get_runtime_arg_info()]
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import os
7
+ import time
8
+ from typing import Any, Optional
9
+
10
+ import torch._inductor.config as config
11
+ from torch._inductor.codecache import cutlass_key
12
+ from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
13
+ from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
14
+ from torch._inductor.runtime.cache_dir_utils import cache_dir
15
+ from torch._inductor.utils import clear_on_fresh_cache
16
+
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+
21
+ CONFIG_PREFIX: str = "configs"
22
+
23
+
24
def get_config_request_key(
    arch: str,
    cuda_version: str,
    instantiation_level: str,
) -> str:
    """
    Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level.
    """
    combined = "-".join(
        [cutlass_key().hex(), arch, cuda_version, instantiation_level]
    )
    digest = hashlib.sha256(combined.encode("utf-8"))
    # Only the first 8 hex chars are needed to disambiguate config files.
    return digest.hexdigest()[:8]
41
+
42
+
43
def _generate_config_filename(request_key: str) -> str:
    """Build the on-disk JSON filename for a serialized full-ops config blob."""
    return "{}_{}.json".format(CONFIG_PREFIX, request_key)
48
+
49
+
50
@clear_on_fresh_cache
@functools.cache
def maybe_fetch_ops() -> Optional[list[Any]]:
    """
    Fetch ops from databases.

    Looks up a pre-generated list of serialized CUTLASS operations, first in
    the local on-disk cache and then (in fbcode builds only) via the remote
    cache. Returns the deserialized ops, or None when caching is force-disabled
    or no cached entry is available. The result is memoized per process by
    functools.cache and reset on fresh-cache events via clear_on_fresh_cache.
    """
    if config.force_disable_caches:
        return None

    # setup
    arch: str = get_cuda_arch()
    # get_cuda_version might return "12.4.0" or "12.4"
    # but we want to use "12.4"
    version: str = ".".join(get_cuda_version().split(".")[:2])
    instantiation_level: str = config.cuda.cutlass_instantiation_level

    # filename and filepath
    request_key: str = get_config_request_key(arch, version, instantiation_level)
    filename: str = _generate_config_filename(request_key)
    filepath: str = os.path.join(cache_dir(), filename)

    # try fetch
    serialized_ops: Optional[list[str]] = None
    start_time = time.time()
    if os.path.isfile(filepath):
        # locally
        try:
            with open(filepath) as f:
                serialized_ops = json.load(f)

            # NOTE: the assert is deliberately inside the try block so that a
            # malformed cache entry degrades to a cache miss (logged below)
            # instead of crashing compilation.
            assert isinstance(serialized_ops, list), (
                f"Expected serialized ops is a list, got {type(serialized_ops)}"
            )
        except Exception as e:
            log.warning(
                "Failed to load CUTLASS config %s from local cache: %s",
                filename,
                e,
            )
            serialized_ops = None
    elif config.is_fbcode():
        # Imported lazily because the fb-internal module does not exist in OSS.
        from torch._inductor.fb.cutlass_remote_cache import (
            maybe_fetch_cutlass_configs_from_remote,
        )

        # from remote
        serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath)

    if serialized_ops is None:
        return None

    # deserialize
    serializer = get_cutlass_operation_serializer()
    full_ops = [serializer.deserialize(x) for x in serialized_ops]  # type: ignore[union-attr]
    log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time)
    return full_ops
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py ADDED
File without changes
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Union
2
+
3
+ from sympy import Expr
4
+
5
+ from torch._inductor.ir import (
6
+ ComputedBuffer,
7
+ InputBuffer,
8
+ is_contiguous_strides_for_shape,
9
+ )
10
+ from torch.utils._ordered_set import OrderedSet
11
+
12
+ from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
13
+
14
+
15
+ EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
16
+ Buffer = Union[ComputedBuffer, InputBuffer]
17
+ CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_.<locals>.TupleType
18
+ CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory.<locals>.VisitorType
19
+ CutlassArgType = (
20
+ Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p
21
+ )
22
+
23
+
24
+ if try_import_cutlass():
25
+ import ast
26
+ import ctypes
27
+ import textwrap
28
+ from typing import Union
29
+
30
+ from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
31
+ EmptyByte,
32
+ )
33
+ from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found]
34
+ dtype2ctype,
35
+ )
36
+ from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
37
+ EpilogueFunctorVisitor,
38
+ )
39
+ from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found]
40
+ FusionCallbacks,
41
+ )
42
+ from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found]
43
+ CollectiveEpilogue,
44
+ )
45
+ from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found]
46
+ PythonASTFrontend,
47
+ )
48
+ from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
49
+ Tensor as CutlassTensor,
50
+ )
51
+ from cutlass_library import (
52
+ DataType,
53
+ EpilogueScheduleType,
54
+ LayoutType,
55
+ TileDescription,
56
+ )
57
+
58
+ from torch._inductor.codegen.cuda import cuda_env
59
+ from torch._inductor.utils import IndentedBuffer
60
+
61
+ _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
62
+
63
+ def create_example_tensors(
64
+ var_name_to_buffer_name: dict[str, str],
65
+ name_to_buffer: dict[str, Buffer],
66
+ size_hint_fn: Callable[[Union[Expr, int]], int],
67
+ ) -> dict[str, CutlassTensor]:
68
+ def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
69
+ shape = buffer.get_layout().size
70
+ stride = buffer.get_layout().stride
71
+ shape = tuple(size_hint_fn(x) for x in shape)
72
+ stride = tuple(size_hint_fn(x) for x in stride)
73
+
74
+ is_row_major = is_contiguous_strides_for_shape(stride, shape)
75
+ is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
76
+
77
+ if not is_row_major and not is_column_major:
78
+ raise RuntimeError(
79
+ f"Cannot create example tensor for {buffer.get_name()} with \
80
+ non-contiguous layout, received stride: {stride} and shape: {shape}"
81
+ )
82
+
83
+ return CutlassTensor(
84
+ shape=shape,
85
+ layout_tag=LayoutType.RowMajor
86
+ if is_row_major
87
+ else LayoutType.ColumnMajor,
88
+ element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
89
+ )
90
+
91
+ return {
92
+ key: cutlass_tensor_from_buffer(name_to_buffer[name])
93
+ for key, name in var_name_to_buffer_name.items()
94
+ }
95
+
96
+ def trace(
97
+ fn_src: str,
98
+ example_tensors: dict[str, CutlassTensor],
99
+ accum_type: DataType,
100
+ output_type: DataType,
101
+ tile_description: TileDescription,
102
+ epilogue_schedule: EpilogueScheduleType,
103
+ name_to_buffer: dict[str, Buffer],
104
+ size_hint_fn: Callable[[Union[Expr, int]], int],
105
+ **kwargs: dict[str, Any],
106
+ ) -> tuple[str, str, str]:
107
+ cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
108
+ assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
109
+ epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs)
110
+ visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
111
+ fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
112
+ collective_epilogue = CollectiveEpilogue(
113
+ tile_description,
114
+ epilogue_schedule,
115
+ accum_type,
116
+ output_type,
117
+ fusion_callbacks,
118
+ )
119
+ evt_name, evt_code = collective_epilogue.emit()
120
+ evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn)
121
+ return evt_name, evt_args, evt_code
122
+
123
# Based off of
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
# This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
def _trace(
    fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any
) -> EpilogueFunctor:
    """Trace epilogue source text into a cutlass EpilogueFunctor.

    Args:
        fn_src: python source of the epilogue, as text (survives
            exec/eval-defined functions, unlike inspect.getsource).
        example_tensors: example inputs keyed by EVT variable name.
        cc: target compute capability to specialize the frontend for.
        **kwargs: forwarded to the PythonASTFrontend constructor.
    """

    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, cc: int, **kwargs: Any):
            self.source = textwrap.dedent(fn_src)
            super().__init__(cc, **kwargs)

        def parse(self, example_inputs: dict[str, CutlassTensor]) -> None:
            self.example_inputs = example_inputs
            self.ast = ast.parse(self.source)
            self.visit(self.ast)

    # Fix: honor the caller-supplied ``cc`` instead of shadowing it with a
    # fresh environment query — previously the parameter was silently ignored.
    epilogue_functor = EpilogueFunctor(cc=cc, **kwargs)
    epilogue_functor.trace(example_tensors)
    return epilogue_functor
+ return epilogue_functor
144
+
145
+ def _render_argument_type(
146
+ epilogue_functor: EpilogueFunctor,
147
+ name_to_buffer: dict[str, Buffer],
148
+ size_hint_fn: Callable[[Union[Expr, int]], int],
149
+ ) -> str:
150
+ epilogue_thread_type = epilogue_functor.epilogue_thread_type
151
+
152
+ # Fragile, but this is the only way to guarantee t is expected type because t is a local class
153
+ def is_nested_visitor_type(t: type) -> bool:
154
+ return (
155
+ ".".join([t.__module__, t.__qualname__])
156
+ == "cutlass.backend.c_types.visitor_factory.<locals>.VisitorType"
157
+ )
158
+
159
+ buffer = IndentedBuffer()
160
+ with buffer.set_tabwidth(2):
161
+
162
+ def render_argument_type(name: str, t: CutlassArgType) -> None:
163
+ if issubclass(t, ctypes.c_byte):
164
+ buffer.writeline(f"{{}}, /* {name} */")
165
+ else:
166
+ fields = [
167
+ (
168
+ fname,
169
+ _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn),
170
+ )
171
+ for fname, ty in t._fields_
172
+ ]
173
+ field_strs = [
174
+ f"/* {fname} */ {str(field)}" for fname, field in fields
175
+ ]
176
+ buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
177
+
178
+ def render_thread_type(name: str, t: CutlassArgType) -> None:
179
+ if is_nested_visitor_type(t):
180
+ buffer.writeline(f"{{ /* {name} */")
181
+ with buffer.indent():
182
+ for name, inner_t in t._fields_:
183
+ render_thread_type(name, inner_t)
184
+ buffer.writeline("},")
185
+ else:
186
+ render_argument_type(name, t)
187
+
188
+ # unroll the recursion once to address special case formatting
189
+ # namely, no ending comma and no indentation for the outermost thread type
190
+ buffer.writeline("{ /* thread */")
191
+ with buffer.indent(3):
192
+ if is_nested_visitor_type(epilogue_thread_type):
193
+ with buffer.indent():
194
+ for name, inner_t in epilogue_thread_type._fields_:
195
+ render_thread_type(name, inner_t)
196
+ else:
197
+ render_argument_type("thread", epilogue_thread_type)
198
+ buffer.writeline("}")
199
+
200
+ return buffer.getvalue()
201
+
202
+ def _get_arg_from_node(
203
+ arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int]
204
+ ) -> str:
205
+ from ..cuda_template import CUTLASSTemplate
206
+
207
+ # Today, arguments are either a pointer to the
208
+ # node's memory, a stride tuple, the datatype
209
+ # Once again, need to check for local class type for stride tuple
210
+ if (
211
+ str(arg_ty)
212
+ == "<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>"
213
+ ):
214
+ DEFAULT_STRIDE_LEN = 3
215
+ assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
216
+ stride = [size_hint_fn(x) for x in node.get_layout().stride]
217
+ for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
218
+ stride.append(0)
219
+
220
+ def render_stride(x: int) -> str:
221
+ # Handle EBO for 0 and 1
222
+ if x == 0:
223
+ return "_0{}"
224
+ elif x == 1:
225
+ return "_1{}"
226
+ else:
227
+ return str(x)
228
+
229
+ return f"{{{', '.join([render_stride(x) for x in stride])}}}"
230
+
231
+ elif issubclass(arg_ty, ctypes.c_void_p):
232
+ return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}"
233
+ elif (
234
+ arg_ty in _CUTLASS_C_DTYPES
235
+ ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
236
+ return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
237
+ elif issubclass(arg_ty, EmptyByte):
238
+ return "{}"
239
+
240
+ raise NotImplementedError(f"Unsupported arg type: {arg_ty}")
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+ from ..cutlass_utils import try_import_cutlass
3
+
4
+
5
+ # copied / modified from original at
6
+ # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
7
+
8
+ if try_import_cutlass():
9
+ import enum
10
+
11
+ from cutlass_library.gemm_operation import * # noqa: F401, F403
12
+ from cutlass_library.library import * # noqa: F401, F403
13
+
14
+ _LOGGER = logging.getLogger(__name__)
15
+
16
+ class EmitGemmUniversal3xInstanceWithEVT:
17
+ """Responsible for emitting a CUTLASS 3.x template definition"""
18
+
19
+ def __init__(self, operation_suffix="", evt_name=None):
20
+ self.operation_suffix = operation_suffix
21
+ self.includes = [
22
+ "cutlass/cutlass.h",
23
+ "cutlass/gemm/gemm.h",
24
+ "cutlass/numeric_types.h",
25
+ "cutlass/gemm/kernel/gemm_universal.hpp",
26
+ "cutlass/gemm/collective/collective_builder.hpp",
27
+ "cutlass/epilogue/collective/collective_builder.hpp",
28
+ ]
29
+ self.builtin_epilogue_functor_template = """${epilogue_functor}<
30
+ ${element_d},
31
+ ${element_epilogue},
32
+ ${element_c},
33
+ ${element_epilogue}
34
+ >"""
35
+
36
+ self.evt_name = evt_name
37
+ self.gemm_template = """
38
+ using ${operation_name}_epilogue =
39
+ typename cutlass::epilogue::collective::CollectiveBuilder<
40
+ ${arch}, ${opcode_class_epi},
41
+ cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
42
+ cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
43
+ ${epi_tile_mn},
44
+ ${element_accumulator}, ${element_epilogue},
45
+ ${element_c}, ${layout_c}, ${align_c},
46
+ ${element_d}, ${layout_d}, ${align_d},
47
+ ${epilogue_schedule},
48
+ ${epilogue_functor}
49
+ >::CollectiveOp;
50
+
51
+ ${mixed_dtype_prepare_code}
52
+
53
+ using ${operation_name}_mainloop =
54
+ typename cutlass::gemm::collective::CollectiveBuilder<
55
+ ${arch}, ${opcode_class_main},
56
+ ${element_a}, ${layout_a}, ${align_a},
57
+ ${element_b}, ${layout_b}, ${align_b},
58
+ ${element_accumulator},
59
+ cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
60
+ cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
61
+ ${stages},
62
+ ${kernel_schedule}
63
+ >::CollectiveOp;
64
+
65
+ // Gemm operator ${operation_name}
66
+ using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
67
+ ${problem_shape},
68
+ ${operation_name}_mainloop,
69
+ ${operation_name}_epilogue,
70
+ ${tile_scheduler}>;
71
+
72
+ // Define named type
73
+ struct ${operation_name} :
74
+ public ${operation_name}_base { };
75
+
76
+ """
77
+
78
+ #
79
+ def instance_template(self):
80
+ return """
81
+ ${compile_guard_start}
82
+ {
83
+ using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
84
+ manifest.append(
85
+ new ${gemm_kind}<GemmKernel>("${operation_name}"));
86
+ }
87
+ ${compile_guard_end}
88
+ """
89
+
90
+ def emit_block_scale_epilogue_functor(self, operation):
91
+ block_scaled_template = """
92
+ ${epilogue_functor}<
93
+ ${epi_vs},
94
+ ${element_d},
95
+ ${element_accumulator},
96
+ ${element_sfd},
97
+ ${layout_sfd},
98
+ ${element_c},
99
+ ${element_scalar}
100
+ >
101
+ """
102
+ block_scaled_values = {
103
+ "epi_vs": str(operation.ScaleFactorVectorSize),
104
+ "element_d": str(DataTypeTag[operation.D.element]),
105
+ "element_sfd": str(DataTypeTag[operation.ScaleFactorD.element]),
106
+ "layout_sfd": LayoutTag[operation.ScaleFactorD.layout],
107
+ "epilogue_functor": EpilogueFunctor3xTag[
108
+ EpilogueFunctor3x.LinearCombinationBlockScaleFactor
109
+ ],
110
+ "element_accumulator": str(DataTypeTag[operation.accumulator_type()]),
111
+ "element_scalar": str(DataTypeTag[operation.accumulator_type()]),
112
+ "element_c": str(DataTypeTag[operation.C.element]),
113
+ }
114
+ return SubstituteTemplate(block_scaled_template, block_scaled_values)
115
+
116
+ @staticmethod
117
+ def pointerize_if_grouped(operation, layout):
118
+ return layout if not is_grouped(operation.gemm_kind) else layout + "* "
119
+
120
+ @staticmethod
121
+ def problem_shape(operation):
122
+ gemm_shape_type = "cute::Shape<int,int,int,int>"
123
+ grouped_gemm_shape_type = "cute::Shape<int,int,int>"
124
+ grouped_gemm_shape_type = (
125
+ "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
126
+ )
127
+
128
+ return (
129
+ gemm_shape_type
130
+ if not is_grouped(operation.gemm_kind)
131
+ else grouped_gemm_shape_type
132
+ )
133
+
134
+ def emit(self, operation):
135
+ """Given a gem operation, emits a template definition of the operation"""
136
+
137
+ opcode_class_main = operation.tile_description.math_instruction.opcode_class
138
+ opcode_class_epi = opcode_class_main
139
+
140
+ tile_shape = operation.tile_description.tile_shape
141
+ instruction_shape = (
142
+ operation.tile_description.math_instruction.instruction_shape
143
+ )
144
+ cluster_m = operation.tile_description.cluster_shape[0]
145
+ cluster_n = operation.tile_description.cluster_shape[1]
146
+
147
+ tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
148
+
149
+ # account for static/dynamic cluster shapes
150
+ cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
151
+ cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
152
+
153
+ # Shape passed to epilogue builder
154
+ is_sm100_kernel = operation.arch == 100
155
+ if is_sm100_kernel:
156
+ cta_m_per_mma_instruction = (
157
+ 2 if "2sm" in operation.procedural_name() else 1
158
+ )
159
+ if cluster_m <= 0:
160
+ cta_m = cta_m // cta_m_per_mma_instruction
161
+
162
+ if opcode_class_main in [
163
+ OpcodeClass.TensorOp,
164
+ OpcodeClass.BlockScaledTensorOp,
165
+ ]:
166
+ tile_shape_m = instruction_shape[0]
167
+ tile_shape_n = instruction_shape[1]
168
+
169
+ # stage count set to zero indicates builder automatic stage selection
170
+ if operation.tile_description.stages > 0:
171
+ stage_count_string = f"cutlass::gemm::collective::StageCount<\
172
+ {str(operation.tile_description.stages)}>"
173
+ else:
174
+ stage_count_string = (
175
+ f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\
176
+ sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
177
+ )
178
+
179
+ epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
180
+
181
+ (
182
+ instance_layout_A,
183
+ instance_layout_B,
184
+ instance_layout_C,
185
+ instance_layout_D,
186
+ ) = (
187
+ operation.A.layout,
188
+ operation.B.layout,
189
+ operation.C.layout,
190
+ operation.D.layout,
191
+ )
192
+
193
+ # 3.0 profiler integration only supports trivial epilogues for now
194
+ epilogue_vector_length = 1
195
+
196
+ # Support built-in epilogue functors or user-defined functions
197
+ if isinstance(operation.epilogue_functor, enum.Enum):
198
+ values = {
199
+ "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
200
+ "epilogue_functor": EpilogueFunctor3xTag[
201
+ operation.epilogue_functor
202
+ ],
203
+ }
204
+ epilogue_functor = SubstituteTemplate(
205
+ self.builtin_epilogue_functor_template, values
206
+ )
207
+
208
+ if (
209
+ is_block_scaled(operation.gemm_kind)
210
+ and operation.ScaleFactorD.element != DataType.void
211
+ ):
212
+ epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
213
+ else:
214
+ epilogue_functor = self.epilogue_functor.emit_declaration()
215
+
216
+ if (
217
+ is_block_scaled(operation.gemm_kind)
218
+ and operation.ScaleFactorD.element != DataType.void
219
+ ):
220
+ epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
221
+
222
+ #
223
+ # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder,
224
+ # e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
225
+ element_a = (
226
+ DataTypeTag[operation.A.element]
227
+ if not operation.is_complex()
228
+ else f"cute::tuple<{str(DataTypeTag[operation.A.element])},\
229
+ {str(ComplexTransformTag3x[operation.A.complex_transform])}>"
230
+ )
231
+ element_b = (
232
+ DataTypeTag[operation.B.element]
233
+ if not operation.is_complex()
234
+ else f"cute::tuple<{str(DataTypeTag[operation.B.element])},\
235
+ {str(ComplexTransformTag3x[operation.B.complex_transform])}>"
236
+ )
237
+ epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
238
+
239
+ if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
240
+ is_no_smem_epilogue = operation.epilogue_schedule in [
241
+ EpilogueScheduleType.NoSmemWarpSpecialized1Sm,
242
+ EpilogueScheduleType.NoSmemWarpSpecialized2Sm,
243
+ ]
244
+ grouped = is_grouped(operation.gemm_kind)
245
+ if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(
246
+ KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped
247
+ ):
248
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
249
+ if not is_no_smem_epilogue:
250
+ epilogue_schedule_type = EpilogueScheduleTag[
251
+ to_grouped_schedule(
252
+ EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped
253
+ )
254
+ ]
255
+ if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(
256
+ KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped
257
+ ):
258
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
259
+ if not is_no_smem_epilogue:
260
+ epilogue_schedule_type = EpilogueScheduleTag[
261
+ to_grouped_schedule(
262
+ EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped
263
+ )
264
+ ]
265
+ element_a = f"cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>"
266
+ element_b = f"cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>"
267
+
268
+ operation_name_str = operation.procedural_name()
269
+ layout_a_str = LayoutTag[instance_layout_A]
270
+ layout_b_str = LayoutTag[instance_layout_B]
271
+ mixed_dtype_prepare_code = ""
272
+ if operation.mixed_input_mode is not None:
273
+ A_dtype = operation.A.element
274
+ B_dtype = operation.B.element
275
+ A_dtype_bits = DataTypeSize[A_dtype]
276
+ B_dtype_bits = DataTypeSize[B_dtype]
277
+ is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
278
+ if is_A_dtype_narrow:
279
+ narrow_dtype, wide_dtype = (A_dtype, B_dtype)
280
+ narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
281
+ else:
282
+ narrow_dtype, wide_dtype = (B_dtype, A_dtype)
283
+ narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
284
+
285
+ narrow_tag = DataTypeTag[narrow_dtype]
286
+ wide_tag = DataTypeTag[wide_dtype]
287
+ scale_tag = DataTypeTag[wide_dtype]
288
+ zero_tag = DataTypeTag[wide_dtype]
289
+
290
+ do_shuffle = False
291
+ value_shuffle_str = ""
292
+ if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
293
+ value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, \
294
+ cute::Stride<cute::_4,cute::_1>>"
295
+ do_shuffle = True
296
+ if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
297
+ value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, \
298
+ cute::Stride<cute::_2,cute::_1>>"
299
+ do_shuffle = True
300
+ do_shuffle = operation.mixed_input_shuffle and do_shuffle
301
+
302
+ if do_shuffle:
303
+ if is_A_dtype_narrow:
304
+ stride_narrow_str = (
305
+ f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
306
+ )
307
+ layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
308
+ else:
309
+ stride_narrow_str = (
310
+ f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
311
+ )
312
+ layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
313
+ # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
314
+ # layout_{a, b}_str are to prevent errors in Windows platform unity build
315
+ mixed_dtype_prepare_code = f"""
316
+ using {operation_name_str}_StrideNarrow = {stride_narrow_str};
317
+ using {operation_name_str}_ValueShuffle = {value_shuffle_str};
318
+ static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
319
+ using {operation_name_str}_MmaAtomShape = \
320
+ cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
321
+ using {operation_name_str}_LayoutAtomQuant = \
322
+ decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, \
323
+ {operation_name_str}_ValueShuffle>());
324
+ using {operation_name_str}_LayoutNarrowReordered = \
325
+ decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, \
326
+ cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
327
+ """
328
+
329
+ mixed_input_modes_to_element = {
330
+ MixedInputMode.ConvertOnly: narrow_tag,
331
+ MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
332
+ MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>",
333
+ }
334
+ narrow_element = mixed_input_modes_to_element.get(
335
+ operation.mixed_input_mode, narrow_tag
336
+ )
337
+
338
+ if narrow_dtype == DataType.s4 and (
339
+ wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2
340
+ ):
341
+ narrow_element = (
342
+ f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
343
+ )
344
+
345
+ if is_A_dtype_narrow:
346
+ element_a = narrow_element
347
+ else:
348
+ element_b = narrow_element
349
+
350
+ if self.evt_name:
351
+ epilogue_functor = self.evt_name
352
+
353
+ values = {
354
+ "operation_name": operation_name_str,
355
+ "operation_suffix": self.operation_suffix,
356
+ "problem_shape": self.problem_shape(operation),
357
+ "element_a": element_a,
358
+ "layout_a": self.pointerize_if_grouped(operation, layout_a_str),
359
+ "element_b": element_b,
360
+ "layout_b": self.pointerize_if_grouped(operation, layout_b_str),
361
+ "element_c": DataTypeTag[operation.C.element],
362
+ "layout_c": self.pointerize_if_grouped(
363
+ operation, LayoutTag[instance_layout_C]
364
+ ),
365
+ "element_d": DataTypeTag[operation.D.element],
366
+ "layout_d": self.pointerize_if_grouped(
367
+ operation, LayoutTag[instance_layout_D]
368
+ ),
369
+ "element_accumulator": DataTypeTag[operation.accumulator_type()],
370
+ "opcode_class_main": OpcodeClassTag[opcode_class_main],
371
+ "opcode_class_epi": OpcodeClassTag[opcode_class_epi],
372
+ "arch": f"cutlass::arch::Sm{operation.arch}",
373
+ "tile_shape_m": str(tile_shape_m),
374
+ "tile_shape_n": str(tile_shape_n),
375
+ "tile_shape_k": str(tile_shape_k),
376
+ "cluster_shape_m": "cute::_"
377
+ + str(operation.tile_description.cluster_shape[0])
378
+ if operation.tile_description.cluster_shape[0] > 0
379
+ else "int",
380
+ "cluster_shape_n": "cute::_"
381
+ + str(operation.tile_description.cluster_shape[1])
382
+ if operation.tile_description.cluster_shape[1] > 0
383
+ else "int",
384
+ "cluster_shape_k": "cute::_"
385
+ + str(operation.tile_description.cluster_shape[2])
386
+ if operation.tile_description.cluster_shape[2] > 0
387
+ else "int",
388
+ "instruction_shape_m": str(instruction_shape[0]),
389
+ "instruction_shape_n": str(instruction_shape[1]),
390
+ "instruction_shape_k": str(instruction_shape[2]),
391
+ "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),
392
+ "epilogue_schedule": str(epilogue_schedule_type),
393
+ "epi_tile_mn": epi_tile_mn,
394
+ "epilogue_functor": epilogue_functor,
395
+ "stages": stage_count_string,
396
+ "align_a": str(operation.A.alignment),
397
+ "align_b": str(operation.B.alignment),
398
+ "align_c": str(operation.C.alignment),
399
+ "align_d": str(operation.C.alignment),
400
+ "transform_a": ComplexTransformTag[operation.A.complex_transform],
401
+ "transform_b": ComplexTransformTag[operation.B.complex_transform],
402
+ "math_operation": MathOperationTag[
403
+ operation.tile_description.math_instruction.math_operation
404
+ ],
405
+ "epilogue_vector_length": str(epilogue_vector_length),
406
+ "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
407
+ "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]),
408
+ "mixed_dtype_prepare_code": mixed_dtype_prepare_code,
409
+ }
410
+
411
+ return SubstituteTemplate(self.gemm_template, values)