Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +177 -0
- .venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py +150 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py +315 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py +339 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py +119 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py +95 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py +31 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +443 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py +175 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py +2691 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py +262 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py +1081 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1777 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py +500 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py +2011 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py +138 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py +597 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py +776 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +878 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py +717 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py +99 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py +27 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +293 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +45 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +674 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +318 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py +105 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +240 -0
- .venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +411 -0
.gitattributes
CHANGED
|
@@ -1825,3 +1825,180 @@ illustrious_generated/e3bdc40b2136.png filter=lfs diff=lfs merge=lfs -text
|
|
| 1825 |
illustrious_generated/6ea1f2330bb2.png filter=lfs diff=lfs merge=lfs -text
|
| 1826 |
illustrious_generated/a08733bfce6d.png filter=lfs diff=lfs merge=lfs -text
|
| 1827 |
illustrious_generated/0aec7eb07d4f.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1825 |
illustrious_generated/6ea1f2330bb2.png filter=lfs diff=lfs merge=lfs -text
|
| 1826 |
illustrious_generated/a08733bfce6d.png filter=lfs diff=lfs merge=lfs -text
|
| 1827 |
illustrious_generated/0aec7eb07d4f.png filter=lfs diff=lfs merge=lfs -text
|
| 1828 |
+
illustrious_generated/699b6be34d1c.png filter=lfs diff=lfs merge=lfs -text
|
| 1829 |
+
illustrious_generated/b8fe03330562.png filter=lfs diff=lfs merge=lfs -text
|
| 1830 |
+
illustrious_generated/efb6fe0239e2.png filter=lfs diff=lfs merge=lfs -text
|
| 1831 |
+
illustrious_generated/a00e6e3ca42d.png filter=lfs diff=lfs merge=lfs -text
|
| 1832 |
+
illustrious_generated/4c39593e352f.png filter=lfs diff=lfs merge=lfs -text
|
| 1833 |
+
illustrious_generated/daf79f3a5a47.png filter=lfs diff=lfs merge=lfs -text
|
| 1834 |
+
illustrious_generated/b66afd053661.png filter=lfs diff=lfs merge=lfs -text
|
| 1835 |
+
illustrious_generated/8a594acaada9.png filter=lfs diff=lfs merge=lfs -text
|
| 1836 |
+
illustrious_generated/532b5ad4f3b4.png filter=lfs diff=lfs merge=lfs -text
|
| 1837 |
+
illustrious_generated/aa2403d0dbde.png filter=lfs diff=lfs merge=lfs -text
|
| 1838 |
+
illustrious_generated/ffa981380c46.png filter=lfs diff=lfs merge=lfs -text
|
| 1839 |
+
illustrious_generated/60f785243e60.png filter=lfs diff=lfs merge=lfs -text
|
| 1840 |
+
illustrious_generated/776bf5e2d2a1.png filter=lfs diff=lfs merge=lfs -text
|
| 1841 |
+
illustrious_generated/5ab3056565d7.png filter=lfs diff=lfs merge=lfs -text
|
| 1842 |
+
illustrious_generated/bde193a32ad5.png filter=lfs diff=lfs merge=lfs -text
|
| 1843 |
+
illustrious_generated/58b9009a8433.png filter=lfs diff=lfs merge=lfs -text
|
| 1844 |
+
illustrious_generated/db68dd8cb2cf.png filter=lfs diff=lfs merge=lfs -text
|
| 1845 |
+
illustrious_generated/79a4d58eda5c.png filter=lfs diff=lfs merge=lfs -text
|
| 1846 |
+
illustrious_generated/6c2bec98397a.png filter=lfs diff=lfs merge=lfs -text
|
| 1847 |
+
illustrious_generated/56b23fd84d9e.png filter=lfs diff=lfs merge=lfs -text
|
| 1848 |
+
illustrious_generated/e38a1e2e958f.png filter=lfs diff=lfs merge=lfs -text
|
| 1849 |
+
illustrious_generated/4014bca6d153.png filter=lfs diff=lfs merge=lfs -text
|
| 1850 |
+
illustrious_generated/754db85fe85f.png filter=lfs diff=lfs merge=lfs -text
|
| 1851 |
+
illustrious_generated/0a4d2753b7c7.png filter=lfs diff=lfs merge=lfs -text
|
| 1852 |
+
illustrious_generated/a0e55ddbf74c.png filter=lfs diff=lfs merge=lfs -text
|
| 1853 |
+
illustrious_generated/c472bc9db7e5.png filter=lfs diff=lfs merge=lfs -text
|
| 1854 |
+
illustrious_generated/07f25b42b8d7.png filter=lfs diff=lfs merge=lfs -text
|
| 1855 |
+
illustrious_generated/a317dff0fe2e.png filter=lfs diff=lfs merge=lfs -text
|
| 1856 |
+
illustrious_generated/1ca2c247cfef.png filter=lfs diff=lfs merge=lfs -text
|
| 1857 |
+
illustrious_generated/92e013a48101.png filter=lfs diff=lfs merge=lfs -text
|
| 1858 |
+
illustrious_generated/53d7fbefadd0.png filter=lfs diff=lfs merge=lfs -text
|
| 1859 |
+
illustrious_generated/0d1561551252.png filter=lfs diff=lfs merge=lfs -text
|
| 1860 |
+
illustrious_generated/13d5cc66d420.png filter=lfs diff=lfs merge=lfs -text
|
| 1861 |
+
illustrious_generated/7429db9846f3.png filter=lfs diff=lfs merge=lfs -text
|
| 1862 |
+
illustrious_generated/49222c07a03f.png filter=lfs diff=lfs merge=lfs -text
|
| 1863 |
+
illustrious_generated/e526540fdbb8.png filter=lfs diff=lfs merge=lfs -text
|
| 1864 |
+
illustrious_generated/0e42a3a7d296.png filter=lfs diff=lfs merge=lfs -text
|
| 1865 |
+
illustrious_generated/29259007fbbf.png filter=lfs diff=lfs merge=lfs -text
|
| 1866 |
+
illustrious_generated/e189b3abf6d2.png filter=lfs diff=lfs merge=lfs -text
|
| 1867 |
+
illustrious_generated/4af1ebb4488c.png filter=lfs diff=lfs merge=lfs -text
|
| 1868 |
+
illustrious_generated/1755674c7bc1.png filter=lfs diff=lfs merge=lfs -text
|
| 1869 |
+
illustrious_generated/9a7e94351f39.png filter=lfs diff=lfs merge=lfs -text
|
| 1870 |
+
illustrious_generated/548d5f09851f.png filter=lfs diff=lfs merge=lfs -text
|
| 1871 |
+
illustrious_generated/8d3200b9e787.png filter=lfs diff=lfs merge=lfs -text
|
| 1872 |
+
illustrious_generated/a3a89fac4fcf.png filter=lfs diff=lfs merge=lfs -text
|
| 1873 |
+
illustrious_generated/e4bf1a6be7c6.png filter=lfs diff=lfs merge=lfs -text
|
| 1874 |
+
illustrious_generated/37c414606b7d.png filter=lfs diff=lfs merge=lfs -text
|
| 1875 |
+
illustrious_generated/c4033b9e6115.png filter=lfs diff=lfs merge=lfs -text
|
| 1876 |
+
illustrious_generated/29c45b6176ff.png filter=lfs diff=lfs merge=lfs -text
|
| 1877 |
+
illustrious_generated/e729db57f47c.png filter=lfs diff=lfs merge=lfs -text
|
| 1878 |
+
illustrious_generated/96a62fa83b20.png filter=lfs diff=lfs merge=lfs -text
|
| 1879 |
+
illustrious_generated/6a51d82ae8bc.png filter=lfs diff=lfs merge=lfs -text
|
| 1880 |
+
illustrious_generated/4632a958366a.png filter=lfs diff=lfs merge=lfs -text
|
| 1881 |
+
illustrious_generated/47aa86a18b5a.png filter=lfs diff=lfs merge=lfs -text
|
| 1882 |
+
illustrious_generated/d1995ff196e3.png filter=lfs diff=lfs merge=lfs -text
|
| 1883 |
+
illustrious_generated/e5ef86680557.png filter=lfs diff=lfs merge=lfs -text
|
| 1884 |
+
illustrious_generated/bd545030c487.png filter=lfs diff=lfs merge=lfs -text
|
| 1885 |
+
illustrious_generated/cef6d8bbd520.png filter=lfs diff=lfs merge=lfs -text
|
| 1886 |
+
illustrious_generated/6b23dba7be7d.png filter=lfs diff=lfs merge=lfs -text
|
| 1887 |
+
illustrious_generated/8d3ab41288cb.png filter=lfs diff=lfs merge=lfs -text
|
| 1888 |
+
illustrious_generated/e31935899f08.png filter=lfs diff=lfs merge=lfs -text
|
| 1889 |
+
illustrious_generated/bc6707dcf1f2.png filter=lfs diff=lfs merge=lfs -text
|
| 1890 |
+
illustrious_generated/b3a626cfe029.png filter=lfs diff=lfs merge=lfs -text
|
| 1891 |
+
illustrious_generated/3503bc2ba8ab.png filter=lfs diff=lfs merge=lfs -text
|
| 1892 |
+
illustrious_generated/ce38aa6d180d.png filter=lfs diff=lfs merge=lfs -text
|
| 1893 |
+
illustrious_generated/6035404a25b1.png filter=lfs diff=lfs merge=lfs -text
|
| 1894 |
+
illustrious_generated/9b1bb21a46e8.png filter=lfs diff=lfs merge=lfs -text
|
| 1895 |
+
illustrious_generated/a4eb52675c75.png filter=lfs diff=lfs merge=lfs -text
|
| 1896 |
+
illustrious_generated/a7acf95bd41f.png filter=lfs diff=lfs merge=lfs -text
|
| 1897 |
+
illustrious_generated/c636dce88f49.png filter=lfs diff=lfs merge=lfs -text
|
| 1898 |
+
illustrious_generated/8ef3560e5a42.png filter=lfs diff=lfs merge=lfs -text
|
| 1899 |
+
illustrious_generated/4f7953e3c61d.png filter=lfs diff=lfs merge=lfs -text
|
| 1900 |
+
illustrious_generated/a9d16660f9be.png filter=lfs diff=lfs merge=lfs -text
|
| 1901 |
+
illustrious_generated/5028ea7283a1.png filter=lfs diff=lfs merge=lfs -text
|
| 1902 |
+
illustrious_generated/259e79d12cba.png filter=lfs diff=lfs merge=lfs -text
|
| 1903 |
+
illustrious_generated/aa08fc4ab5f7.png filter=lfs diff=lfs merge=lfs -text
|
| 1904 |
+
illustrious_generated/a048f190d616.png filter=lfs diff=lfs merge=lfs -text
|
| 1905 |
+
illustrious_generated/64d5e838ff32.png filter=lfs diff=lfs merge=lfs -text
|
| 1906 |
+
illustrious_generated/f8b258d2d426.png filter=lfs diff=lfs merge=lfs -text
|
| 1907 |
+
illustrious_generated/73e5f02dca91.png filter=lfs diff=lfs merge=lfs -text
|
| 1908 |
+
illustrious_generated/176f37a6cd7f.png filter=lfs diff=lfs merge=lfs -text
|
| 1909 |
+
illustrious_generated/75bc8bf6da61.png filter=lfs diff=lfs merge=lfs -text
|
| 1910 |
+
illustrious_generated/831ac5f05b4e.png filter=lfs diff=lfs merge=lfs -text
|
| 1911 |
+
illustrious_generated/6061d4cce84a.png filter=lfs diff=lfs merge=lfs -text
|
| 1912 |
+
illustrious_generated/1970ccc17731.png filter=lfs diff=lfs merge=lfs -text
|
| 1913 |
+
illustrious_generated/46d7a6324536.png filter=lfs diff=lfs merge=lfs -text
|
| 1914 |
+
illustrious_generated/593a5b377fb6.png filter=lfs diff=lfs merge=lfs -text
|
| 1915 |
+
illustrious_generated/9e588e17280b.png filter=lfs diff=lfs merge=lfs -text
|
| 1916 |
+
illustrious_generated/f086d40d76b7.png filter=lfs diff=lfs merge=lfs -text
|
| 1917 |
+
illustrious_generated/e5b4630753fb.png filter=lfs diff=lfs merge=lfs -text
|
| 1918 |
+
illustrious_generated/bf19c6825d9d.png filter=lfs diff=lfs merge=lfs -text
|
| 1919 |
+
illustrious_generated/ad90d731a0fa.png filter=lfs diff=lfs merge=lfs -text
|
| 1920 |
+
illustrious_generated/4c7a209b8887.png filter=lfs diff=lfs merge=lfs -text
|
| 1921 |
+
illustrious_generated/b52834e10fd4.png filter=lfs diff=lfs merge=lfs -text
|
| 1922 |
+
illustrious_generated/e0fa2a6e6b42.png filter=lfs diff=lfs merge=lfs -text
|
| 1923 |
+
illustrious_generated/b770edd1abe0.png filter=lfs diff=lfs merge=lfs -text
|
| 1924 |
+
illustrious_generated/bf8fcd7e3ebb.png filter=lfs diff=lfs merge=lfs -text
|
| 1925 |
+
illustrious_generated/b2a469095584.png filter=lfs diff=lfs merge=lfs -text
|
| 1926 |
+
illustrious_generated/4af582043c66.png filter=lfs diff=lfs merge=lfs -text
|
| 1927 |
+
illustrious_generated/666f1a84cb80.png filter=lfs diff=lfs merge=lfs -text
|
| 1928 |
+
illustrious_generated/c87663be02a7.png filter=lfs diff=lfs merge=lfs -text
|
| 1929 |
+
illustrious_generated/5fa4fb280d76.png filter=lfs diff=lfs merge=lfs -text
|
| 1930 |
+
illustrious_generated/c65bd94bf971.png filter=lfs diff=lfs merge=lfs -text
|
| 1931 |
+
illustrious_generated/a3f618970fb7.png filter=lfs diff=lfs merge=lfs -text
|
| 1932 |
+
illustrious_generated/c588cc386611.png filter=lfs diff=lfs merge=lfs -text
|
| 1933 |
+
illustrious_generated/e412e3f1af16.png filter=lfs diff=lfs merge=lfs -text
|
| 1934 |
+
illustrious_generated/acf002111806.png filter=lfs diff=lfs merge=lfs -text
|
| 1935 |
+
illustrious_generated/3dbcf74b53f1.png filter=lfs diff=lfs merge=lfs -text
|
| 1936 |
+
illustrious_generated/149a82673fe9.png filter=lfs diff=lfs merge=lfs -text
|
| 1937 |
+
illustrious_generated/4536122bec42.png filter=lfs diff=lfs merge=lfs -text
|
| 1938 |
+
illustrious_generated/25df61f584ba.png filter=lfs diff=lfs merge=lfs -text
|
| 1939 |
+
illustrious_generated/5803e88c45f3.png filter=lfs diff=lfs merge=lfs -text
|
| 1940 |
+
illustrious_generated/d5f24e4f2554.png filter=lfs diff=lfs merge=lfs -text
|
| 1941 |
+
illustrious_generated/57264915c81a.png filter=lfs diff=lfs merge=lfs -text
|
| 1942 |
+
illustrious_generated/b5a6eaf580ce.png filter=lfs diff=lfs merge=lfs -text
|
| 1943 |
+
illustrious_generated/18f1f90221cc.png filter=lfs diff=lfs merge=lfs -text
|
| 1944 |
+
illustrious_generated/798df3adbd88.png filter=lfs diff=lfs merge=lfs -text
|
| 1945 |
+
illustrious_generated/617ce74ffafa.png filter=lfs diff=lfs merge=lfs -text
|
| 1946 |
+
illustrious_generated/1dfed5dbd442.png filter=lfs diff=lfs merge=lfs -text
|
| 1947 |
+
illustrious_generated/243fd5bd7679.png filter=lfs diff=lfs merge=lfs -text
|
| 1948 |
+
illustrious_generated/2433395d6d43.png filter=lfs diff=lfs merge=lfs -text
|
| 1949 |
+
illustrious_generated/96ab8498c7a6.png filter=lfs diff=lfs merge=lfs -text
|
| 1950 |
+
illustrious_generated/4a574003bbc9.png filter=lfs diff=lfs merge=lfs -text
|
| 1951 |
+
illustrious_generated/d209e317052f.png filter=lfs diff=lfs merge=lfs -text
|
| 1952 |
+
illustrious_generated/fb9ed29c63d7.png filter=lfs diff=lfs merge=lfs -text
|
| 1953 |
+
illustrious_generated/ece3750ad6a6.png filter=lfs diff=lfs merge=lfs -text
|
| 1954 |
+
illustrious_generated/690e7719d4bf.png filter=lfs diff=lfs merge=lfs -text
|
| 1955 |
+
illustrious_generated/639cf9531d64.png filter=lfs diff=lfs merge=lfs -text
|
| 1956 |
+
illustrious_generated/56e66190e7ce.png filter=lfs diff=lfs merge=lfs -text
|
| 1957 |
+
illustrious_generated/ceb937e0801e.png filter=lfs diff=lfs merge=lfs -text
|
| 1958 |
+
illustrious_generated/ed0e6ab9592e.png filter=lfs diff=lfs merge=lfs -text
|
| 1959 |
+
illustrious_generated/312f15890343.png filter=lfs diff=lfs merge=lfs -text
|
| 1960 |
+
illustrious_generated/d21efb5f15d1.png filter=lfs diff=lfs merge=lfs -text
|
| 1961 |
+
illustrious_generated/b01a98afee9a.png filter=lfs diff=lfs merge=lfs -text
|
| 1962 |
+
illustrious_generated/3cbd0d9a36a7.png filter=lfs diff=lfs merge=lfs -text
|
| 1963 |
+
illustrious_generated/6994acf3db4f.png filter=lfs diff=lfs merge=lfs -text
|
| 1964 |
+
illustrious_generated/b5ee5bce0095.png filter=lfs diff=lfs merge=lfs -text
|
| 1965 |
+
illustrious_generated/e150c196c1de.png filter=lfs diff=lfs merge=lfs -text
|
| 1966 |
+
illustrious_generated/d0d5e2ddb402.png filter=lfs diff=lfs merge=lfs -text
|
| 1967 |
+
illustrious_generated/b102a265b15d.png filter=lfs diff=lfs merge=lfs -text
|
| 1968 |
+
illustrious_generated/b0202f93d233.png filter=lfs diff=lfs merge=lfs -text
|
| 1969 |
+
illustrious_generated/14cde88e8c02.png filter=lfs diff=lfs merge=lfs -text
|
| 1970 |
+
illustrious_generated/30f0d4a3b6b6.png filter=lfs diff=lfs merge=lfs -text
|
| 1971 |
+
illustrious_generated/33aeb710e683.png filter=lfs diff=lfs merge=lfs -text
|
| 1972 |
+
illustrious_generated/71485a12c097.png filter=lfs diff=lfs merge=lfs -text
|
| 1973 |
+
illustrious_generated/2c3ca96d94bc.png filter=lfs diff=lfs merge=lfs -text
|
| 1974 |
+
illustrious_generated/a630cf5674f3.png filter=lfs diff=lfs merge=lfs -text
|
| 1975 |
+
illustrious_generated/163359d65694.png filter=lfs diff=lfs merge=lfs -text
|
| 1976 |
+
illustrious_generated/2f054f26c97c.png filter=lfs diff=lfs merge=lfs -text
|
| 1977 |
+
illustrious_generated/af2a10d42166.png filter=lfs diff=lfs merge=lfs -text
|
| 1978 |
+
illustrious_generated/4d7d6abe842c.png filter=lfs diff=lfs merge=lfs -text
|
| 1979 |
+
illustrious_generated/62fb4f26c8d0.png filter=lfs diff=lfs merge=lfs -text
|
| 1980 |
+
illustrious_generated/8bc741c00933.png filter=lfs diff=lfs merge=lfs -text
|
| 1981 |
+
illustrious_generated/f3021039e6df.png filter=lfs diff=lfs merge=lfs -text
|
| 1982 |
+
illustrious_generated/7ed913ac6954.png filter=lfs diff=lfs merge=lfs -text
|
| 1983 |
+
illustrious_generated/a9ab451bd4c9.png filter=lfs diff=lfs merge=lfs -text
|
| 1984 |
+
illustrious_generated/49c90b238c77.png filter=lfs diff=lfs merge=lfs -text
|
| 1985 |
+
illustrious_generated/1371e80e6f16.png filter=lfs diff=lfs merge=lfs -text
|
| 1986 |
+
illustrious_generated/f284fb62c9bf.png filter=lfs diff=lfs merge=lfs -text
|
| 1987 |
+
illustrious_generated/71e9866c5494.png filter=lfs diff=lfs merge=lfs -text
|
| 1988 |
+
illustrious_generated/abf601407cb0.png filter=lfs diff=lfs merge=lfs -text
|
| 1989 |
+
illustrious_generated/83da1d18e129.png filter=lfs diff=lfs merge=lfs -text
|
| 1990 |
+
illustrious_generated/484afd596fd0.png filter=lfs diff=lfs merge=lfs -text
|
| 1991 |
+
illustrious_generated/4c0f1f03dcd2.png filter=lfs diff=lfs merge=lfs -text
|
| 1992 |
+
illustrious_generated/2008ee2e6f66.png filter=lfs diff=lfs merge=lfs -text
|
| 1993 |
+
illustrious_generated/48fcc54ce758.png filter=lfs diff=lfs merge=lfs -text
|
| 1994 |
+
illustrious_generated/4769b078890d.png filter=lfs diff=lfs merge=lfs -text
|
| 1995 |
+
illustrious_generated/b716f3a23f11.png filter=lfs diff=lfs merge=lfs -text
|
| 1996 |
+
illustrious_generated/8e740eadaf18.png filter=lfs diff=lfs merge=lfs -text
|
| 1997 |
+
illustrious_generated/88e259bbe761.png filter=lfs diff=lfs merge=lfs -text
|
| 1998 |
+
illustrious_generated/476617799f8d.png filter=lfs diff=lfs merge=lfs -text
|
| 1999 |
+
illustrious_generated/1bb92008067a.png filter=lfs diff=lfs merge=lfs -text
|
| 2000 |
+
illustrious_generated/38cdd5aae9bb.png filter=lfs diff=lfs merge=lfs -text
|
| 2001 |
+
illustrious_generated/5e8555d59ee2.png filter=lfs diff=lfs merge=lfs -text
|
| 2002 |
+
illustrious_generated/f3f41fd3689e.png filter=lfs diff=lfs merge=lfs -text
|
| 2003 |
+
illustrious_generated/ca3147d568a3.png filter=lfs diff=lfs merge=lfs -text
|
| 2004 |
+
illustrious_generated/8e375573fe30.png filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (41.4 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MMRankingA100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 166912
|
| 27 |
+
and str(metadata.device_capa) == "(8, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
|
| 61 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 62 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 63 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 64 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 65 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 66 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
|
| 67 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 68 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 69 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 70 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 71 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 72 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 73 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 74 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 75 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
|
| 76 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 77 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
|
| 78 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
|
| 79 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
|
| 80 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 81 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
|
| 82 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 83 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
|
| 84 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 85 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
|
| 86 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
|
| 87 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
|
| 88 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 89 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 90 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 91 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 92 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 93 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 94 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
|
| 95 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
|
| 96 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 97 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 98 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 99 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 100 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 101 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 102 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 103 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 104 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 105 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 106 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 107 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
|
| 108 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
|
| 109 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 110 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
|
| 111 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 112 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 113 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
|
| 114 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 115 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
|
| 116 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 117 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 118 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
|
| 119 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 120 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
|
| 121 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 122 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 123 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 124 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 125 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 126 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 127 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 128 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 129 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 130 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 131 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
|
| 132 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 133 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 134 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
|
| 135 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
|
| 136 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
|
| 137 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 138 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 139 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 140 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 141 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 142 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
|
| 143 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
|
| 144 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 145 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 146 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 147 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 148 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 149 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
|
| 150 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 151 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 152 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 153 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 154 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
|
| 155 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 156 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 157 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 158 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 159 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 160 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
|
| 161 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 162 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
|
| 163 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 164 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 165 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 166 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 167 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 168 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
|
| 169 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 170 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 171 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 172 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
|
| 173 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 174 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
|
| 175 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
|
| 176 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
|
| 177 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
|
| 178 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
|
| 179 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 180 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 181 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 182 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 183 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 184 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 185 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 186 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 187 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
|
| 188 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 189 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 190 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 191 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 192 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 193 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
|
| 194 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 195 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 196 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 197 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 198 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 199 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 200 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 201 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 202 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 203 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 204 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 205 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 206 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 207 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
|
| 208 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 209 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 210 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 211 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 212 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 213 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 214 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 215 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 216 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 217 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 218 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 219 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 220 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 221 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 222 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 223 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 224 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 225 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
|
| 226 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 227 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
|
| 228 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 229 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 230 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 231 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
|
| 232 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 233 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 234 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 235 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 236 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 237 |
+
|
| 238 |
+
def get_name(self) -> str:
|
| 239 |
+
return 'mm'
|
| 240 |
+
|
| 241 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
|
| 242 |
+
if context.get_value('arith_intensity') <= 52.6245059967041:
|
| 243 |
+
if context.get_value('n') <= 34.0:
|
| 244 |
+
if context.get_value('n') <= 18.0:
|
| 245 |
+
if context.get_value('k*n') <= 312.0:
|
| 246 |
+
return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
|
| 247 |
+
else:
|
| 248 |
+
if context.get_value('k') <= 40.0:
|
| 249 |
+
return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
|
| 250 |
+
else:
|
| 251 |
+
return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
|
| 252 |
+
else:
|
| 253 |
+
if context.get_value('mat1_stride_0') <= 20.0:
|
| 254 |
+
return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
|
| 255 |
+
else:
|
| 256 |
+
if context.get_value('k') <= 68.0:
|
| 257 |
+
return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
|
| 258 |
+
else:
|
| 259 |
+
return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
|
| 260 |
+
else:
|
| 261 |
+
if context.get_value('k') <= 35.0:
|
| 262 |
+
if context.get_value('k') <= 18.0:
|
| 263 |
+
if context.get_value('m*n') <= 19505152.0:
|
| 264 |
+
return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
|
| 265 |
+
else:
|
| 266 |
+
return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
|
| 267 |
+
else:
|
| 268 |
+
if context.get_value('n') <= 68.0:
|
| 269 |
+
return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
|
| 270 |
+
else:
|
| 271 |
+
return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
|
| 272 |
+
else:
|
| 273 |
+
if context.get_value('m*n') <= 309760.0:
|
| 274 |
+
return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
|
| 275 |
+
else:
|
| 276 |
+
if context.get_value('n') <= 72.0:
|
| 277 |
+
return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
|
| 278 |
+
else:
|
| 279 |
+
return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
|
| 280 |
+
else:
|
| 281 |
+
if str(context.get_value('using_tf32')) != 'False':
|
| 282 |
+
if context.get_value('m*n') <= 815360.0:
|
| 283 |
+
if context.get_value('k') <= 1184.0:
|
| 284 |
+
return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
|
| 285 |
+
else:
|
| 286 |
+
return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
|
| 287 |
+
else:
|
| 288 |
+
if context.get_value('arith_intensity') <= 187.23922729492188:
|
| 289 |
+
if context.get_value('mat1_stride_0') <= 198.0:
|
| 290 |
+
return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
|
| 291 |
+
else:
|
| 292 |
+
return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
|
| 293 |
+
else:
|
| 294 |
+
return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
|
| 295 |
+
else:
|
| 296 |
+
return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MMRankingH100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 232448
|
| 27 |
+
and str(metadata.device_capa) == "(9, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 61 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 62 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 63 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 64 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
|
| 65 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 66 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 67 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 68 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 69 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 70 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 71 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 72 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 73 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
|
| 74 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 75 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
|
| 76 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
|
| 77 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 78 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
|
| 79 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 80 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
|
| 81 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 82 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
|
| 83 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 84 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
|
| 85 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
|
| 86 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
|
| 87 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 88 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 89 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 90 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 91 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 92 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 93 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
|
| 94 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
|
| 95 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 96 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 97 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 98 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 99 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 100 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 101 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 102 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 103 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 104 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 105 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 106 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
|
| 107 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
|
| 108 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 109 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
|
| 110 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 111 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 112 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
|
| 113 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 114 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
|
| 115 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 116 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 117 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
|
| 118 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 119 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
|
| 120 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 121 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 122 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 123 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 124 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 125 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 126 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 127 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 128 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 129 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 130 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 131 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 132 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 133 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
|
| 134 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
|
| 135 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
|
| 136 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 137 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 138 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
|
| 139 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
|
| 140 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
|
| 141 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 142 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 143 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 144 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 145 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 146 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 147 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 148 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 149 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 150 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
|
| 151 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
|
| 152 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
|
| 153 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
|
| 154 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 155 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 156 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 157 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 158 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 159 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 160 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 161 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 162 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
|
| 163 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 164 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 165 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 166 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
|
| 167 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 168 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
|
| 169 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 170 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 171 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 172 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 173 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 174 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
|
| 175 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 176 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 177 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 178 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
|
| 179 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
|
| 180 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 181 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
|
| 182 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
|
| 183 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
|
| 184 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
|
| 185 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
|
| 186 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 187 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 188 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 189 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 190 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 191 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 192 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 193 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
|
| 194 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 195 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 196 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 197 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 198 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 199 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 200 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 201 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 202 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 203 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 204 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 205 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 206 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 207 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 208 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 209 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 210 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 211 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
|
| 212 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 213 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 214 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 215 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 216 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 217 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 218 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 219 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 220 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 221 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 222 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 223 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 224 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 225 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 226 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 227 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 228 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 229 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
|
| 230 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 231 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
|
| 232 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 233 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 234 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 235 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
|
| 236 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 237 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 238 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 239 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 240 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 241 |
+
|
| 242 |
+
def get_name(self) -> str:
|
| 243 |
+
return 'mm'
|
| 244 |
+
|
| 245 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
|
| 246 |
+
if context.get_value('arith_intensity') <= 29.89772129058838:
|
| 247 |
+
if context.get_value('n') <= 34.0:
|
| 248 |
+
if context.get_value('n') <= 18.0:
|
| 249 |
+
if context.get_value('k*n') <= 432.0:
|
| 250 |
+
if context.get_value('arith_intensity') <= 7.8700292110443115:
|
| 251 |
+
return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
|
| 252 |
+
else:
|
| 253 |
+
return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
|
| 254 |
+
else:
|
| 255 |
+
if context.get_value('k') <= 40.0:
|
| 256 |
+
return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
|
| 257 |
+
else:
|
| 258 |
+
return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
|
| 259 |
+
else:
|
| 260 |
+
if context.get_value('mat1_stride_0') <= 40.0:
|
| 261 |
+
if context.get_value('mat1_stride_0') <= 20.0:
|
| 262 |
+
return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
|
| 263 |
+
else:
|
| 264 |
+
return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
|
| 265 |
+
else:
|
| 266 |
+
if context.get_value('mat1_stride_0') <= 68.0:
|
| 267 |
+
return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
|
| 268 |
+
else:
|
| 269 |
+
return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
|
| 270 |
+
else:
|
| 271 |
+
if context.get_value('k') <= 18.0:
|
| 272 |
+
if context.get_value('m*k') <= 528.0:
|
| 273 |
+
return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
|
| 274 |
+
else:
|
| 275 |
+
if context.get_value('n') <= 80.0:
|
| 276 |
+
return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
|
| 277 |
+
else:
|
| 278 |
+
return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
|
| 279 |
+
else:
|
| 280 |
+
if context.get_value('k') <= 36.0:
|
| 281 |
+
if context.get_value('n') <= 68.0:
|
| 282 |
+
return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
|
| 283 |
+
else:
|
| 284 |
+
return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
|
| 285 |
+
else:
|
| 286 |
+
if context.get_value('mat2_stride_0') <= 384.0:
|
| 287 |
+
return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
|
| 288 |
+
else:
|
| 289 |
+
return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
|
| 290 |
+
else:
|
| 291 |
+
if context.get_value('arith_intensity') <= 56.995582580566406:
|
| 292 |
+
if context.get_value('n') <= 68.0:
|
| 293 |
+
if context.get_value('k*n') <= 4448.0:
|
| 294 |
+
if context.get_value('m*n') <= 29626368.0:
|
| 295 |
+
return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
|
| 296 |
+
else:
|
| 297 |
+
return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
|
| 298 |
+
else:
|
| 299 |
+
if context.get_value('k') <= 348.0:
|
| 300 |
+
return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
|
| 301 |
+
else:
|
| 302 |
+
return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
|
| 303 |
+
else:
|
| 304 |
+
if context.get_value('m') <= 3264.0:
|
| 305 |
+
return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
|
| 306 |
+
else:
|
| 307 |
+
if context.get_value('k') <= 62.5:
|
| 308 |
+
return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
|
| 309 |
+
else:
|
| 310 |
+
return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
|
| 311 |
+
else:
|
| 312 |
+
if context.get_value('m*n') <= 1097728.0:
|
| 313 |
+
return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
|
| 314 |
+
else:
|
| 315 |
+
if context.get_value('m*n') <= 3244032.0:
|
| 316 |
+
return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
|
| 317 |
+
else:
|
| 318 |
+
if context.get_value('n') <= 136.0:
|
| 319 |
+
return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
|
| 320 |
+
else:
|
| 321 |
+
return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MixedMMA100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 166912
|
| 27 |
+
and str(metadata.device_capa) == "(8, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_fallback_mixed_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 61 |
+
|
| 62 |
+
def get_name(self) -> str:
|
| 63 |
+
return 'mixed_mm'
|
| 64 |
+
|
| 65 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
|
| 66 |
+
if str(context.get_value('1LEQmLEQ16')) != 'True':
|
| 67 |
+
if context.get_value('m') <= 32.5:
|
| 68 |
+
if context.get_value('n') <= 6976.0:
|
| 69 |
+
if context.get_value('n') <= 3520.0:
|
| 70 |
+
if context.get_value('m*n') <= 37632.0:
|
| 71 |
+
return None
|
| 72 |
+
else:
|
| 73 |
+
return [(1.000, 13)]
|
| 74 |
+
else:
|
| 75 |
+
if context.get_value('m*k') <= 452352.0:
|
| 76 |
+
return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
|
| 77 |
+
else:
|
| 78 |
+
return [(0.778, 8), (0.222, 13)]
|
| 79 |
+
else:
|
| 80 |
+
if context.get_value('k*n') <= 102776832.0:
|
| 81 |
+
if context.get_value('n') <= 14656.0:
|
| 82 |
+
return [(1.000, 11)]
|
| 83 |
+
else:
|
| 84 |
+
return [(0.889, 11), (0.111, 13)]
|
| 85 |
+
else:
|
| 86 |
+
return [(1.000, 11)]
|
| 87 |
+
else:
|
| 88 |
+
if context.get_value('m*n') <= 446464.0:
|
| 89 |
+
if context.get_value('m*n') <= 223424.0:
|
| 90 |
+
if context.get_value('mat1_stride_0') <= 3968.0:
|
| 91 |
+
return None
|
| 92 |
+
else:
|
| 93 |
+
return None
|
| 94 |
+
else:
|
| 95 |
+
if context.get_value('m*n') <= 346112.0:
|
| 96 |
+
return [(0.960, 16), (0.040, 7)]
|
| 97 |
+
else:
|
| 98 |
+
return [(0.750, 16), (0.136, 14), (0.114, 7)]
|
| 99 |
+
else:
|
| 100 |
+
if str(context.get_value('33LEQmLEQ64')) != 'True':
|
| 101 |
+
if context.get_value('n') <= 6976.0:
|
| 102 |
+
return [(1.000, 14)]
|
| 103 |
+
else:
|
| 104 |
+
return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
|
| 105 |
+
else:
|
| 106 |
+
if context.get_value('n') <= 13888.0:
|
| 107 |
+
return [(0.710, 14), (0.275, 21), (0.014, 12)]
|
| 108 |
+
else:
|
| 109 |
+
return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
|
| 110 |
+
else:
|
| 111 |
+
if context.get_value('n') <= 3520.0:
|
| 112 |
+
if context.get_value('arith_intensity') <= 3.994754433631897:
|
| 113 |
+
if str(context.get_value('mat2_dtype')) != 'torch.uint8':
|
| 114 |
+
if context.get_value('m*k') <= 18944.0:
|
| 115 |
+
return [(0.577, 5), (0.423, 6)]
|
| 116 |
+
else:
|
| 117 |
+
return [(0.988, 5), (0.012, 6)]
|
| 118 |
+
else:
|
| 119 |
+
if context.get_value('arith_intensity') <= 2.9899919033050537:
|
| 120 |
+
return None
|
| 121 |
+
else:
|
| 122 |
+
return None
|
| 123 |
+
else:
|
| 124 |
+
if context.get_value('arith_intensity') <= 7.956453561782837:
|
| 125 |
+
if context.get_value('k*n') <= 9244032.0:
|
| 126 |
+
return [(0.822, 5), (0.178, 6)]
|
| 127 |
+
else:
|
| 128 |
+
return [(0.977, 5), (0.023, 0)]
|
| 129 |
+
else:
|
| 130 |
+
if context.get_value('m*k') <= 978944.0:
|
| 131 |
+
return [(1.000, 5)]
|
| 132 |
+
else:
|
| 133 |
+
return [(0.971, 5), (0.029, 0)]
|
| 134 |
+
else:
|
| 135 |
+
if context.get_value('n') <= 13632.0:
|
| 136 |
+
if context.get_value('n') <= 6976.0:
|
| 137 |
+
return [(1.000, 6)]
|
| 138 |
+
else:
|
| 139 |
+
if context.get_value('k') <= 3968.0:
|
| 140 |
+
return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
|
| 141 |
+
else:
|
| 142 |
+
return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
|
| 143 |
+
else:
|
| 144 |
+
if context.get_value('k*n') <= 39518208.0:
|
| 145 |
+
return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
|
| 146 |
+
else:
|
| 147 |
+
if context.get_value('n') <= 20800.0:
|
| 148 |
+
return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
|
| 149 |
+
else:
|
| 150 |
+
return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MixedMMH100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 232448
|
| 27 |
+
and str(metadata.device_capa) == "(9, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_fallback_mixed_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 60 |
+
|
| 61 |
+
def get_name(self) -> str:
|
| 62 |
+
return 'mixed_mm'
|
| 63 |
+
|
| 64 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
|
| 65 |
+
if context.get_value('arith_intensity') <= 15.988086223602295:
|
| 66 |
+
if context.get_value('n') <= 25280.0:
|
| 67 |
+
if context.get_value('n') <= 1344.0:
|
| 68 |
+
if context.get_value('mat1_stride_0') <= 7808.0:
|
| 69 |
+
return [(0.581, 7), (0.419, 6)]
|
| 70 |
+
else:
|
| 71 |
+
if context.get_value('m*n') <= 7680.0:
|
| 72 |
+
return [(0.875, 0), (0.125, 6)]
|
| 73 |
+
else:
|
| 74 |
+
return [(0.833, 0), (0.167, 7)]
|
| 75 |
+
else:
|
| 76 |
+
if context.get_value('n') <= 8512.0:
|
| 77 |
+
if str(context.get_value('mat2_dtype')) != 'torch.int8':
|
| 78 |
+
return [(0.763, 6), (0.237, 7)]
|
| 79 |
+
else:
|
| 80 |
+
return [(0.725, 7), (0.275, 6)]
|
| 81 |
+
else:
|
| 82 |
+
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
|
| 83 |
+
return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
|
| 84 |
+
else:
|
| 85 |
+
return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
|
| 86 |
+
else:
|
| 87 |
+
if context.get_value('n') <= 42254.0:
|
| 88 |
+
if context.get_value('n') <= 33856.0:
|
| 89 |
+
if context.get_value('k*n') <= 68157440.0:
|
| 90 |
+
return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
|
| 91 |
+
else:
|
| 92 |
+
return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
|
| 93 |
+
else:
|
| 94 |
+
return [(0.659, 5), (0.341, 6)]
|
| 95 |
+
else:
|
| 96 |
+
if context.get_value('k*n') <= 326052992.0:
|
| 97 |
+
if context.get_value('n') <= 55232.0:
|
| 98 |
+
return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
|
| 99 |
+
else:
|
| 100 |
+
return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
|
| 101 |
+
else:
|
| 102 |
+
if context.get_value('n') <= 57024.0:
|
| 103 |
+
return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
|
| 104 |
+
else:
|
| 105 |
+
return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
|
| 106 |
+
else:
|
| 107 |
+
if context.get_value('m*n') <= 543936.0:
|
| 108 |
+
if str(context.get_value('17LEQmLEQ32')) != 'True':
|
| 109 |
+
if context.get_value('m*n') <= 262272.0:
|
| 110 |
+
if context.get_value('n') <= 1592.5:
|
| 111 |
+
return [(0.860, 0), (0.140, 9)]
|
| 112 |
+
else:
|
| 113 |
+
return None
|
| 114 |
+
else:
|
| 115 |
+
if context.get_value('m*k') <= 1294336.0:
|
| 116 |
+
return [(0.833, 17), (0.150, 18), (0.017, 15)]
|
| 117 |
+
else:
|
| 118 |
+
return [(0.917, 17), (0.083, 8)]
|
| 119 |
+
else:
|
| 120 |
+
if context.get_value('n') <= 12416.0:
|
| 121 |
+
if context.get_value('m*n') <= 43008.0:
|
| 122 |
+
return None
|
| 123 |
+
else:
|
| 124 |
+
return [(0.853, 14), (0.147, 9)]
|
| 125 |
+
else:
|
| 126 |
+
return [(0.625, 12), (0.375, 14)]
|
| 127 |
+
else:
|
| 128 |
+
if context.get_value('m') <= 32.5:
|
| 129 |
+
if context.get_value('mat2_stride_1') <= 6656.0:
|
| 130 |
+
if context.get_value('n') <= 69184.0:
|
| 131 |
+
return [(0.611, 12), (0.361, 14), (0.028, 13)]
|
| 132 |
+
else:
|
| 133 |
+
return [(1.000, 12)]
|
| 134 |
+
else:
|
| 135 |
+
if context.get_value('mat2_stride_1') <= 20864.0:
|
| 136 |
+
return [(1.000, 12)]
|
| 137 |
+
else:
|
| 138 |
+
return [(0.958, 12), (0.042, 9)]
|
| 139 |
+
else:
|
| 140 |
+
if context.get_value('m*n') <= 1085440.0:
|
| 141 |
+
if context.get_value('n') <= 9152.0:
|
| 142 |
+
return [(1.000, 18)]
|
| 143 |
+
else:
|
| 144 |
+
return [(0.780, 18), (0.160, 16), (0.060, 20)]
|
| 145 |
+
else:
|
| 146 |
+
if context.get_value('m') <= 67.0:
|
| 147 |
+
return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
|
| 148 |
+
else:
|
| 149 |
+
return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
|
| 5 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
|
| 6 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 7 |
+
LearnedHeuristicRegression,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PadMMA100(LearnedHeuristicRegression):
|
| 12 |
+
|
| 13 |
+
def __init__(self) -> None:
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 17 |
+
return (
|
| 18 |
+
metadata.name == self.get_name()
|
| 19 |
+
and metadata.shared_memory == 166912
|
| 20 |
+
and str(metadata.device_capa) == "(8, 0)"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def get_feedback(self, context: AHContext, choice: Choice) -> float:
|
| 24 |
+
context.context_dict[CHOICE_COL] = choice
|
| 25 |
+
return self.predict(context)
|
| 26 |
+
|
| 27 |
+
def get_confidence_threshold(self) -> float:
|
| 28 |
+
return 1.7025303314066
|
| 29 |
+
|
| 30 |
+
def get_name(self) -> str:
|
| 31 |
+
return 'pad_mm'
|
| 32 |
+
|
| 33 |
+
def predict(self, context: AHContext) -> float:
|
| 34 |
+
if str(context.get_value('choice')) != 'pad':
|
| 35 |
+
if str(context.get_value('using_tf32')) != 'False':
|
| 36 |
+
if context.get_value('m*n') <= 4171264.0:
|
| 37 |
+
if context.get_value('m*k') <= 3999308.0:
|
| 38 |
+
return 1.8751469764071178
|
| 39 |
+
else:
|
| 40 |
+
if str(context.get_value('n_multiple_32')) != 'True':
|
| 41 |
+
return 0.9117231355626345
|
| 42 |
+
else:
|
| 43 |
+
return 1.1607689608873861
|
| 44 |
+
else:
|
| 45 |
+
if str(context.get_value('n_multiple_2')) != 'True':
|
| 46 |
+
if str(context.get_value('using_tf32')) != 'True':
|
| 47 |
+
return 0.7430382200435992
|
| 48 |
+
else:
|
| 49 |
+
return 0.8531269794448678
|
| 50 |
+
else:
|
| 51 |
+
if str(context.get_value('k_multiple_2')) != 'True':
|
| 52 |
+
return 0.7577181972719917
|
| 53 |
+
else:
|
| 54 |
+
return 0.8977349440424219
|
| 55 |
+
else:
|
| 56 |
+
if context.get_value('m*n') <= 1299712.0:
|
| 57 |
+
return 1.1669723418995592
|
| 58 |
+
else:
|
| 59 |
+
if context.get_value('mat2_stride_1') <= 45217.5:
|
| 60 |
+
if context.get_value('m*n') <= 55884158.0:
|
| 61 |
+
return 1.0262769936909601
|
| 62 |
+
else:
|
| 63 |
+
return 1.0022677428470845
|
| 64 |
+
else:
|
| 65 |
+
if context.get_value('m') <= 18478.0:
|
| 66 |
+
return 1.1127066261894312
|
| 67 |
+
else:
|
| 68 |
+
return 1.0337740659894263
|
| 69 |
+
else:
|
| 70 |
+
if str(context.get_value('mat1_dtype')) != 'torch.float32':
|
| 71 |
+
if str(context.get_value('n_multiple_2')) != 'False':
|
| 72 |
+
if str(context.get_value('k_multiple_2')) != 'True':
|
| 73 |
+
if context.get_value('mat1_stride_0') <= 561.0:
|
| 74 |
+
return 1.2900382135142956
|
| 75 |
+
else:
|
| 76 |
+
return 1.5761737616057887
|
| 77 |
+
else:
|
| 78 |
+
if context.get_value('num_dims_needs_padding') <= 1.5:
|
| 79 |
+
return 1.0472263310239422
|
| 80 |
+
else:
|
| 81 |
+
return 1.1727673465762514
|
| 82 |
+
else:
|
| 83 |
+
if context.get_value('k') <= 28238.5:
|
| 84 |
+
if context.get_value('k/(m*n)') <= 0.00026227018679492176:
|
| 85 |
+
return 1.6770542505397175
|
| 86 |
+
else:
|
| 87 |
+
return 1.3974785435105923
|
| 88 |
+
else:
|
| 89 |
+
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
|
| 90 |
+
return 1.3952699800111992
|
| 91 |
+
else:
|
| 92 |
+
return 1.5759286511628336
|
| 93 |
+
else:
|
| 94 |
+
if str(context.get_value('using_tf32')) != 'False':
|
| 95 |
+
if context.get_value('m*n') <= 14119424.0:
|
| 96 |
+
return 0.8875772670422478
|
| 97 |
+
else:
|
| 98 |
+
if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
|
| 99 |
+
return 1.1467728924377265
|
| 100 |
+
else:
|
| 101 |
+
return 1.215842963532998
|
| 102 |
+
else:
|
| 103 |
+
if context.get_value('arith_intensity') <= 396.8774871826172:
|
| 104 |
+
return 0.89940161869551
|
| 105 |
+
else:
|
| 106 |
+
if context.get_value('mat2_stride_1') <= 45217.5:
|
| 107 |
+
return 0.9964328169353532
|
| 108 |
+
else:
|
| 109 |
+
return 0.9493479238294826
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
AHOperation,
|
| 11 |
+
Choice,
|
| 12 |
+
CHOICE_COL,
|
| 13 |
+
Feedback,
|
| 14 |
+
FEEDBACK_COL,
|
| 15 |
+
get_metadata_str_from_log,
|
| 16 |
+
)
|
| 17 |
+
from torch._inductor.autoheuristic.learned_heuristic_controller import (
|
| 18 |
+
LearnedHeuristicController,
|
| 19 |
+
)
|
| 20 |
+
from torch._inductor.ir import ChoiceCaller
|
| 21 |
+
from torch._inductor.runtime.runtime_utils import cache_dir
|
| 22 |
+
from torch._inductor.utils import get_gpu_shared_memory
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LocalFeedback:
|
| 26 |
+
"""
|
| 27 |
+
To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
|
| 28 |
+
LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice
|
| 29 |
+
(see pad_mm.py, where the autotuning happens locally, for an example).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
|
| 33 |
+
self.feedback_fn = feedback_fn
|
| 34 |
+
|
| 35 |
+
def __call__(self, choice: Choice) -> Feedback:
|
| 36 |
+
return self.feedback_fn(choice)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class InconsistentMetadata(Exception):
|
| 40 |
+
"""
|
| 41 |
+
Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does
|
| 42 |
+
not match the metadata it would store if the file didn't exist.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AutoHeuristic:
|
| 47 |
+
"""
|
| 48 |
+
AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
|
| 49 |
+
generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
|
| 50 |
+
a heuristic (see torchgen/autoheuristic/).
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
collected_feedback: dict[Choice, Feedback]
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
fallback: Callable[[], Choice],
|
| 58 |
+
choices: list[Choice],
|
| 59 |
+
feedback: Optional[LocalFeedback],
|
| 60 |
+
context: AHContext,
|
| 61 |
+
name: str,
|
| 62 |
+
augment_context: Optional[list[AHOperation]] = None,
|
| 63 |
+
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
|
| 64 |
+
) -> None:
|
| 65 |
+
"""
|
| 66 |
+
Initializes an instance of the AutoHeuristic class.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
|
| 70 |
+
AutoHeuristic is in data collection mode.
|
| 71 |
+
choices: A list of possible choices the heuristic can make.
|
| 72 |
+
feedback: An instance of LocalFeedback that provides feedback for a given choice.
|
| 73 |
+
context: Context to store with each choice and feedback.
|
| 74 |
+
name: A string that identifies the heuristic.
|
| 75 |
+
augment_context: An optional list of AHOperation instances that augment the context.
|
| 76 |
+
precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
|
| 77 |
+
"""
|
| 78 |
+
self.fallback = fallback
|
| 79 |
+
self.choices = choices
|
| 80 |
+
self.feedback = feedback
|
| 81 |
+
self.context = context
|
| 82 |
+
self.name = name
|
| 83 |
+
self.collected_feedback = {}
|
| 84 |
+
self.augment_context = augment_context
|
| 85 |
+
self.metadata = AHMetadata(
|
| 86 |
+
get_gpu_shared_memory(),
|
| 87 |
+
torch.cuda.get_device_capability(),
|
| 88 |
+
self.choices,
|
| 89 |
+
self.name,
|
| 90 |
+
)
|
| 91 |
+
self.precondition = precondition
|
| 92 |
+
|
| 93 |
+
if not self.satisfies_precondition():
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
|
| 97 |
+
self.log_path = self.get_default_log_path()
|
| 98 |
+
else:
|
| 99 |
+
self.log_path = torch._inductor.config.autoheuristic_log_path
|
| 100 |
+
|
| 101 |
+
if torch._inductor.config.collect_autoheuristic(self.name):
|
| 102 |
+
if self.feedback is not None:
|
| 103 |
+
for choice in self.choices:
|
| 104 |
+
feedback_val = self.feedback(choice)
|
| 105 |
+
self.save_data(choice, feedback_val)
|
| 106 |
+
|
| 107 |
+
def satisfies_precondition(self) -> bool:
|
| 108 |
+
return self.precondition is None or self.precondition(
|
| 109 |
+
self.metadata, self.context
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def get_choice(self) -> Choice:
|
| 113 |
+
"""
|
| 114 |
+
Returns the chosen option based on the value of autoheuristic_use.
|
| 115 |
+
If self.name is one of the comma separated strings in autoheuristic_use,
|
| 116 |
+
it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
if not self.satisfies_precondition():
|
| 120 |
+
return self.fallback()
|
| 121 |
+
|
| 122 |
+
if torch._inductor.config.use_autoheuristic(self.name):
|
| 123 |
+
if self.augment_context is not None:
|
| 124 |
+
self.context.apply_operations(self.augment_context)
|
| 125 |
+
controller = LearnedHeuristicController(
|
| 126 |
+
self.metadata,
|
| 127 |
+
self.context,
|
| 128 |
+
)
|
| 129 |
+
decision = controller.get_decision()
|
| 130 |
+
if decision not in self.choices:
|
| 131 |
+
# TODO(AlnisM): We might want to allow this in the future
|
| 132 |
+
return self.fallback()
|
| 133 |
+
if decision is not None:
|
| 134 |
+
return decision
|
| 135 |
+
return self.fallback()
|
| 136 |
+
|
| 137 |
+
def get_top_k_choices(
|
| 138 |
+
self, top_k: int, always_included: Optional[list[str]] = None
|
| 139 |
+
) -> Optional[list[Choice]]:
|
| 140 |
+
if not self.satisfies_precondition():
|
| 141 |
+
return None
|
| 142 |
+
if torch._inductor.config.use_autoheuristic(self.name):
|
| 143 |
+
if self.augment_context is not None:
|
| 144 |
+
self.context.apply_operations(self.augment_context)
|
| 145 |
+
controller = LearnedHeuristicController(
|
| 146 |
+
self.metadata,
|
| 147 |
+
self.context,
|
| 148 |
+
)
|
| 149 |
+
choices = controller.get_decisions_ranked(top_k)
|
| 150 |
+
if choices is None:
|
| 151 |
+
return None
|
| 152 |
+
if always_included is not None:
|
| 153 |
+
for choice in always_included:
|
| 154 |
+
if choice not in choices:
|
| 155 |
+
choices.append(choice)
|
| 156 |
+
return choices
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
def get_collected_feedback(self, choice: Choice) -> Any:
|
| 160 |
+
return self.collected_feedback.get(choice, None)
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def get_device_identifier() -> str:
|
| 164 |
+
# a heuristic might work well for one GPU, but not for another
|
| 165 |
+
# we store the collected data per GPU model and learn a heuristic per GPU model
|
| 166 |
+
|
| 167 |
+
# TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
|
| 168 |
+
device_name = torch.cuda.get_device_name().replace(" ", "_")
|
| 169 |
+
return device_name
|
| 170 |
+
|
| 171 |
+
def get_default_log_path(self) -> str:
|
| 172 |
+
device_name = self.get_device_identifier()
|
| 173 |
+
path = f"{cache_dir()}/autoheuristic/{device_name}/"
|
| 174 |
+
os.makedirs(path, exist_ok=True)
|
| 175 |
+
path += f"{self.name}.txt"
|
| 176 |
+
return path
|
| 177 |
+
|
| 178 |
+
def serialize_metadata(self) -> str:
|
| 179 |
+
metadata_dict = self.metadata.to_dict()
|
| 180 |
+
(
|
| 181 |
+
num_features,
|
| 182 |
+
cat_features,
|
| 183 |
+
) = self.context.get_numerical_and_categorical_features()
|
| 184 |
+
metadata_dict["numerical_features"] = num_features
|
| 185 |
+
metadata_dict["categorical_features"] = cat_features
|
| 186 |
+
return json.dumps(metadata_dict)
|
| 187 |
+
|
| 188 |
+
def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
|
| 189 |
+
self.collected_feedback[choice] = feedback_val
|
| 190 |
+
log_path = self.log_path
|
| 191 |
+
|
| 192 |
+
lines = []
|
| 193 |
+
log_exists = os.path.exists(log_path)
|
| 194 |
+
if log_exists:
|
| 195 |
+
# if log already exists, make sure it is consistent
|
| 196 |
+
metadata = self.serialize_metadata()
|
| 197 |
+
existing_metadata = get_metadata_str_from_log(self.log_path)
|
| 198 |
+
if existing_metadata != metadata:
|
| 199 |
+
raise InconsistentMetadata(
|
| 200 |
+
"Given metadata does not match existing metadata"
|
| 201 |
+
)
|
| 202 |
+
else:
|
| 203 |
+
lines.append(self.serialize_metadata())
|
| 204 |
+
feature_header = self.context.get_feature_names_csv()
|
| 205 |
+
header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
|
| 206 |
+
lines.append(header)
|
| 207 |
+
|
| 208 |
+
line = ""
|
| 209 |
+
feature_values = self.context.get_feature_values_csv()
|
| 210 |
+
line += feature_values + "," + choice + "," + str(feedback_val)
|
| 211 |
+
lines.append(line)
|
| 212 |
+
|
| 213 |
+
with open(log_path, "a") as f:
|
| 214 |
+
f.write("\n".join(lines) + "\n")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
|
| 218 |
+
"""
|
| 219 |
+
AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
|
| 220 |
+
when one wants to use AutoHeuristic for kernel choice selection.
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
fallback: Callable[[], Optional[ChoiceCaller]],
|
| 226 |
+
choices: list[ChoiceCaller],
|
| 227 |
+
input_nodes: list[Any],
|
| 228 |
+
context: AHContext,
|
| 229 |
+
name: str,
|
| 230 |
+
augment_context: Optional[list[AHOperation]] = None,
|
| 231 |
+
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
|
| 232 |
+
) -> None:
|
| 233 |
+
"""
|
| 234 |
+
The arguments choices, input_nodes and name have to match the ones used in the call to
|
| 235 |
+
autotune_select_algorithm(), e.g. if the following call is made
|
| 236 |
+
autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
|
| 237 |
+
have to be used here.
|
| 238 |
+
"""
|
| 239 |
+
self.input_nodes = input_nodes
|
| 240 |
+
self.choicestr2choice: dict[str, ChoiceCaller] = {}
|
| 241 |
+
for choice in choices:
|
| 242 |
+
self.choicestr2choice[choice.autoheuristic_id()] = choice
|
| 243 |
+
choices_str = list(self.choicestr2choice.keys())
|
| 244 |
+
|
| 245 |
+
def fallback_str() -> str:
|
| 246 |
+
fallback_choice = fallback()
|
| 247 |
+
if fallback_choice is None:
|
| 248 |
+
# TODO: Find a nicer way to handle this
|
| 249 |
+
return "unsure"
|
| 250 |
+
return fallback_choice.autoheuristic_id()
|
| 251 |
+
|
| 252 |
+
super().__init__(
|
| 253 |
+
fallback_str,
|
| 254 |
+
choices_str,
|
| 255 |
+
None,
|
| 256 |
+
context,
|
| 257 |
+
name,
|
| 258 |
+
augment_context,
|
| 259 |
+
precondition,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
if (
|
| 263 |
+
torch._inductor.config.collect_autoheuristic(self.name)
|
| 264 |
+
and self.satisfies_precondition()
|
| 265 |
+
):
|
| 266 |
+
self.register_global_feedback(input_nodes, choices)
|
| 267 |
+
|
| 268 |
+
def register_global_feedback(
|
| 269 |
+
self, input_nodes: list[Any], choices: list[ChoiceCaller]
|
| 270 |
+
) -> None:
|
| 271 |
+
"""
|
| 272 |
+
Registers a callback in select_algorithm, which is called with the timing of each choice.
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
from torch._inductor.select_algorithm import (
|
| 276 |
+
add_feedback_saver,
|
| 277 |
+
create_inputs_key,
|
| 278 |
+
create_precompile_key,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def store_global_feedback(
|
| 282 |
+
ah_inputs_key: str,
|
| 283 |
+
ah_precompile_key: str,
|
| 284 |
+
timings: dict[ChoiceCaller, float],
|
| 285 |
+
name: str,
|
| 286 |
+
input_nodes: list[Any],
|
| 287 |
+
choices: list[ChoiceCaller],
|
| 288 |
+
) -> None:
|
| 289 |
+
current_inputs_key = create_inputs_key(input_nodes)
|
| 290 |
+
if current_inputs_key != ah_inputs_key:
|
| 291 |
+
return
|
| 292 |
+
current_precompile_key = create_precompile_key(
|
| 293 |
+
name, current_inputs_key, choices
|
| 294 |
+
)
|
| 295 |
+
if current_precompile_key != ah_precompile_key:
|
| 296 |
+
return
|
| 297 |
+
for choice, time in timings.items():
|
| 298 |
+
self.save_data(choice.autoheuristic_id(), time)
|
| 299 |
+
|
| 300 |
+
inputs_key = create_inputs_key(input_nodes)
|
| 301 |
+
precompile_key = create_precompile_key(self.name, inputs_key, choices)
|
| 302 |
+
feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
|
| 303 |
+
add_feedback_saver(feedback_saver)
|
| 304 |
+
|
| 305 |
+
def get_choice_caller(self) -> Optional[ChoiceCaller]:
|
| 306 |
+
choice = self.get_choice()
|
| 307 |
+
return self.choicestr2choice.get(choice, None)
|
| 308 |
+
|
| 309 |
+
def get_top_k_choices_caller(
|
| 310 |
+
self, top_k: int, always_included: Optional[list[str]] = None
|
| 311 |
+
) -> Optional[list[ChoiceCaller]]:
|
| 312 |
+
choices = self.get_top_k_choices(top_k, always_included)
|
| 313 |
+
if choices is None:
|
| 314 |
+
return None
|
| 315 |
+
return [self.choicestr2choice[choice] for choice in choices]
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from typing import Any, Callable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
Feedback = float
|
| 8 |
+
Choice = str
|
| 9 |
+
Value = Any
|
| 10 |
+
|
| 11 |
+
CHOICE_COL = "choice"
|
| 12 |
+
FEEDBACK_COL = "feedback"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AHFeature:
|
| 16 |
+
"""
|
| 17 |
+
The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is
|
| 18 |
+
categorical (i.e., not a continuous variable) to learn a machine learning model.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
|
| 22 |
+
self.name = name
|
| 23 |
+
self.value = value
|
| 24 |
+
self.is_categorical = is_categorical
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AHOperation:
|
| 28 |
+
"""
|
| 29 |
+
AHOperation can be used to augment the data collected by AutoHeuristic.
|
| 30 |
+
One might for example store features like m, k, n, but also want to use
|
| 31 |
+
features like m*n, or k*n, to learn a heuristic. Instead of storing features
|
| 32 |
+
that can be created from the collected data, one can use AHOperation to
|
| 33 |
+
create new features from the collected data.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
|
| 38 |
+
) -> None:
|
| 39 |
+
self.name = name
|
| 40 |
+
self.func = func
|
| 41 |
+
self.is_categorical = is_categorical
|
| 42 |
+
|
| 43 |
+
def apply_operation(self, data: Any) -> None:
|
| 44 |
+
data[self.name] = self.func(data)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AHContext:
|
| 48 |
+
"""
|
| 49 |
+
This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will
|
| 50 |
+
store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
|
| 51 |
+
information that will help to learn a heuristic.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
features: list[AHFeature]
|
| 55 |
+
context_dict: dict[str, Value]
|
| 56 |
+
|
| 57 |
+
def __init__(self) -> None:
|
| 58 |
+
self.features = []
|
| 59 |
+
self.context_dict = {}
|
| 60 |
+
|
| 61 |
+
def add_feature(
|
| 62 |
+
self, name: str, value: Value, is_categorical: bool = False
|
| 63 |
+
) -> None:
|
| 64 |
+
self.features.append(AHFeature(name, value, is_categorical=is_categorical))
|
| 65 |
+
self.context_dict[name] = value
|
| 66 |
+
|
| 67 |
+
def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
|
| 68 |
+
numerical_features = []
|
| 69 |
+
categorical_features = []
|
| 70 |
+
for feature in self.features:
|
| 71 |
+
if feature.is_categorical:
|
| 72 |
+
categorical_features.append(feature.name)
|
| 73 |
+
else:
|
| 74 |
+
numerical_features.append(feature.name)
|
| 75 |
+
|
| 76 |
+
return numerical_features, categorical_features
|
| 77 |
+
|
| 78 |
+
def get_feature_names_csv(self) -> str:
|
| 79 |
+
return ",".join(feature.name for feature in self.features)
|
| 80 |
+
|
| 81 |
+
def get_feature_values_csv(self) -> str:
|
| 82 |
+
return ",".join(str(feature.value) for feature in self.features)
|
| 83 |
+
|
| 84 |
+
def get_value(self, name: str) -> Value:
|
| 85 |
+
return self.context_dict[name]
|
| 86 |
+
|
| 87 |
+
def apply_operations(self, operations: list[AHOperation]) -> None:
|
| 88 |
+
for op in operations:
|
| 89 |
+
op.apply_operation(self.context_dict)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class AHMetadata:
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
shared_memory: Any,
|
| 96 |
+
device_capa: tuple[int, int],
|
| 97 |
+
choices: list[Choice],
|
| 98 |
+
name: str,
|
| 99 |
+
) -> None:
|
| 100 |
+
# use amount of shared_memory and device_capability to identify GPU
|
| 101 |
+
# TODO(AlnisM): there might be a better way to do this
|
| 102 |
+
self.shared_memory = shared_memory
|
| 103 |
+
self.device_capa = device_capa
|
| 104 |
+
self.choices = choices
|
| 105 |
+
self.name = name
|
| 106 |
+
|
| 107 |
+
def to_dict(self) -> dict[str, Value]:
|
| 108 |
+
return {
|
| 109 |
+
"shared_memory": self.shared_memory,
|
| 110 |
+
"device_capa": self.device_capa,
|
| 111 |
+
"name": self.name,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_metadata_str_from_log(log_path: str) -> str:
|
| 116 |
+
with open(log_path, newline="") as file:
|
| 117 |
+
json_string = file.readline().strip()
|
| 118 |
+
return json_string
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def check_minsize(context: AHContext, minsize: int) -> bool:
|
| 122 |
+
return (
|
| 123 |
+
context.get_value("m") >= minsize
|
| 124 |
+
and context.get_value("k") >= minsize
|
| 125 |
+
and context.get_value("n") >= minsize
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
|
| 130 |
+
if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0):
|
| 131 |
+
# A100 precondition
|
| 132 |
+
return check_minsize(context, 512)
|
| 133 |
+
elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0):
|
| 134 |
+
# H100 precondition
|
| 135 |
+
return check_minsize(context, 768)
|
| 136 |
+
return True
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
|
| 140 |
+
m = context.get_value("m")
|
| 141 |
+
k = context.get_value("k")
|
| 142 |
+
n = context.get_value("n")
|
| 143 |
+
if m > 128 or k < 1024 or n < 1024:
|
| 144 |
+
return False
|
| 145 |
+
mat1_iscontig = context.get_value("mat1_iscontig")
|
| 146 |
+
mat2_iscontig = context.get_value("mat2_iscontig")
|
| 147 |
+
return mat1_iscontig and not mat2_iscontig
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_mult_dims_ops() -> list[AHOperation]:
|
| 151 |
+
m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
|
| 152 |
+
m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
|
| 153 |
+
k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
|
| 154 |
+
return [m_times_k_op, m_times_n_op, k_times_n_op]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_arith_intensity(data: Any) -> float:
|
| 158 |
+
m = data["m"]
|
| 159 |
+
k = data["k"]
|
| 160 |
+
n = data["n"]
|
| 161 |
+
if m == 0 or k == 0 or n == 0:
|
| 162 |
+
return 0.0
|
| 163 |
+
return m * k * n / (m * k + k * n + m * n)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def pad_mm_operations() -> list[AHOperation]:
|
| 167 |
+
mult_dims_ops = get_mult_dims_ops()
|
| 168 |
+
k_div_m_times_n_op = AHOperation(
|
| 169 |
+
"k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def bfloat_perf_hit(data: Any) -> bool:
|
| 173 |
+
m = data["m"]
|
| 174 |
+
k = data["k"]
|
| 175 |
+
n = data["n"]
|
| 176 |
+
is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
|
| 177 |
+
return k > (m * 1024) and k > (n * 1024) and is_bfloat
|
| 178 |
+
|
| 179 |
+
bfloat_perf_hit_op = AHOperation(
|
| 180 |
+
"bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
|
| 184 |
+
dims_need_padding_ops = get_dims_need_padding_ops()
|
| 185 |
+
dims_multiple_ops = get_dims_multiple_ops()
|
| 186 |
+
is_contig_ops = get_is_contig_ops()
|
| 187 |
+
|
| 188 |
+
ah_operations = mult_dims_ops + [
|
| 189 |
+
k_div_m_times_n_op,
|
| 190 |
+
bfloat_perf_hit_op,
|
| 191 |
+
arith_intensity_op,
|
| 192 |
+
]
|
| 193 |
+
ah_operations.extend(dims_need_padding_ops)
|
| 194 |
+
ah_operations.extend(dims_multiple_ops)
|
| 195 |
+
ah_operations.extend(is_contig_ops)
|
| 196 |
+
return ah_operations
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
|
| 200 |
+
return data[dim] >= lower and data[dim] <= upper
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def between_ops() -> list[AHOperation]:
|
| 204 |
+
dims = ["m", "k", "n"]
|
| 205 |
+
limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
|
| 206 |
+
ah_operations = []
|
| 207 |
+
for dim in dims:
|
| 208 |
+
for lower, upper in limits:
|
| 209 |
+
between_op_fn = functools.partial(
|
| 210 |
+
between_op, dim=dim, lower=lower, upper=upper
|
| 211 |
+
)
|
| 212 |
+
# using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
|
| 213 |
+
between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
|
| 214 |
+
ah_operations.append(
|
| 215 |
+
AHOperation(between_op_name, between_op_fn, is_categorical=True)
|
| 216 |
+
)
|
| 217 |
+
return ah_operations
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
|
| 221 |
+
return data[dim] == 2**exponent
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def mm_operations() -> list[AHOperation]:
|
| 225 |
+
mult_dims_ops = get_mult_dims_ops()
|
| 226 |
+
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
|
| 227 |
+
return mult_dims_ops + [arith_intensity_op]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def mixed_mm_operations() -> list[AHOperation]:
|
| 231 |
+
return mm_operations() + between_ops()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def is_multiple(data: Any, dim: str, mult: int) -> bool:
|
| 235 |
+
return data[dim] % mult == 0
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_dims_multiple_ops() -> list[AHOperation]:
|
| 239 |
+
multiples = [2, 4, 8, 16, 32]
|
| 240 |
+
dims = ["m", "k", "n"]
|
| 241 |
+
dims_multiple_ops = []
|
| 242 |
+
for dim in dims:
|
| 243 |
+
for mult in multiples:
|
| 244 |
+
is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
|
| 245 |
+
dims_multiple_op = AHOperation(
|
| 246 |
+
f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
|
| 247 |
+
)
|
| 248 |
+
dims_multiple_ops.append(dims_multiple_op)
|
| 249 |
+
return dims_multiple_ops
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_dims_need_padding_ops() -> list[AHOperation]:
|
| 253 |
+
def mat1_innermost_needs_padding_fn(data: Any) -> bool:
|
| 254 |
+
mat1_stride_0 = data["mat1_stride_0"]
|
| 255 |
+
mat1_stride_1 = data["mat1_stride_1"]
|
| 256 |
+
m_padded_length = data["m_padded_length"]
|
| 257 |
+
k_padded_length = data["k_padded_length"]
|
| 258 |
+
mat1_innermost_needs_padding = False
|
| 259 |
+
if mat1_stride_0 == 1 and m_padded_length != 0:
|
| 260 |
+
mat1_innermost_needs_padding = True
|
| 261 |
+
if mat1_stride_1 == 1 and k_padded_length != 0:
|
| 262 |
+
mat1_innermost_needs_padding = True
|
| 263 |
+
return mat1_innermost_needs_padding
|
| 264 |
+
|
| 265 |
+
mat1_innermost_op = AHOperation(
|
| 266 |
+
"mat1_innermost_needs_padding",
|
| 267 |
+
mat1_innermost_needs_padding_fn,
|
| 268 |
+
is_categorical=True,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
def mat2_innermost_needs_padding_fn(data: Any) -> bool:
|
| 272 |
+
mat2_stride_0 = data["mat2_stride_0"]
|
| 273 |
+
mat2_stride_1 = data["mat2_stride_1"]
|
| 274 |
+
k_padded_length = data["k_padded_length"]
|
| 275 |
+
n_padded_length = data["n_padded_length"]
|
| 276 |
+
mat2_innermost_needs_padding = False
|
| 277 |
+
if mat2_stride_0 == 1 and k_padded_length != 0:
|
| 278 |
+
mat2_innermost_needs_padding = True
|
| 279 |
+
if mat2_stride_1 == 1 and n_padded_length != 0:
|
| 280 |
+
mat2_innermost_needs_padding = True
|
| 281 |
+
return mat2_innermost_needs_padding
|
| 282 |
+
|
| 283 |
+
mat2_innermost_op = AHOperation(
|
| 284 |
+
"mat2_innermost_needs_padding",
|
| 285 |
+
mat2_innermost_needs_padding_fn,
|
| 286 |
+
is_categorical=True,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def num_dims_needs_padding_fn(data: Any) -> int:
|
| 290 |
+
m_padded_length = data["m_padded_length"]
|
| 291 |
+
k_padded_length = data["k_padded_length"]
|
| 292 |
+
n_padded_length = data["n_padded_length"]
|
| 293 |
+
num_dims_needs_padding = 0
|
| 294 |
+
if m_padded_length != 0:
|
| 295 |
+
num_dims_needs_padding += 1
|
| 296 |
+
if k_padded_length != 0:
|
| 297 |
+
num_dims_needs_padding += 1
|
| 298 |
+
if n_padded_length != 0:
|
| 299 |
+
num_dims_needs_padding += 1
|
| 300 |
+
return num_dims_needs_padding
|
| 301 |
+
|
| 302 |
+
num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
|
| 303 |
+
return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_is_contig_ops() -> list[AHOperation]:
|
| 307 |
+
def mat1_is_contig_fn(data: Any) -> bool:
|
| 308 |
+
stride_0 = data["mat1_stride_0"]
|
| 309 |
+
stride_1 = data["mat1_stride_1"]
|
| 310 |
+
k = data["k"]
|
| 311 |
+
return stride_0 == k and stride_1 == 1
|
| 312 |
+
|
| 313 |
+
mat1_is_contig_op = AHOperation(
|
| 314 |
+
"mat1_iscontig", mat1_is_contig_fn, is_categorical=True
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
def mat2_is_contig_fn(data: Any) -> bool:
|
| 318 |
+
stride_0 = data["mat2_stride_0"]
|
| 319 |
+
stride_1 = data["mat2_stride_1"]
|
| 320 |
+
n = data["n"]
|
| 321 |
+
return stride_0 == n and stride_1 == 1
|
| 322 |
+
|
| 323 |
+
mat2_is_contig_op = AHOperation(
|
| 324 |
+
"mat2_iscontig", mat2_is_contig_fn, is_categorical=True
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
return [mat1_is_contig_op, mat2_is_contig_op]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None:
|
| 331 |
+
for i, s in enumerate(stride):
|
| 332 |
+
context.add_feature(f"{name}_stride_{i}", s)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
|
| 336 |
+
using_tf32 = "not_float_32"
|
| 337 |
+
if dtype == torch.float32:
|
| 338 |
+
using_tf32 = torch.backends.cuda.matmul.allow_tf32
|
| 339 |
+
context.add_feature("using_tf32", using_tf32, is_categorical=True)
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import inspect
|
| 3 |
+
import pkgutil
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def find_and_instantiate_subclasses(
|
| 16 |
+
package_name: str, base_class: Any
|
| 17 |
+
) -> list[LearnedHeuristic]:
|
| 18 |
+
instances = []
|
| 19 |
+
|
| 20 |
+
package = importlib.import_module(package_name)
|
| 21 |
+
for _, module_name, _ in pkgutil.walk_packages(
|
| 22 |
+
package.__path__, package.__name__ + "."
|
| 23 |
+
):
|
| 24 |
+
try:
|
| 25 |
+
module_basename = module_name.split(".")[-1]
|
| 26 |
+
if not module_basename.startswith("_"):
|
| 27 |
+
# learned heuristics start with an underscore
|
| 28 |
+
continue
|
| 29 |
+
module = importlib.import_module(module_name)
|
| 30 |
+
|
| 31 |
+
# look for classes that are subclasses of base_class
|
| 32 |
+
for _name, obj in inspect.getmembers(module):
|
| 33 |
+
if (
|
| 34 |
+
inspect.isclass(obj)
|
| 35 |
+
and issubclass(obj, base_class)
|
| 36 |
+
and obj != base_class
|
| 37 |
+
):
|
| 38 |
+
instance = obj()
|
| 39 |
+
instances.append(instance)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
print(f"Error processing module {module_name}: {e}")
|
| 42 |
+
|
| 43 |
+
return instances
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LearnedHeuristicController:
|
| 47 |
+
"""
|
| 48 |
+
Class that finds and instantiates all learned heuristics. It also provides
|
| 49 |
+
a way to get the decision of a learned heuristic.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
|
| 53 |
+
"""
|
| 54 |
+
A dictionary that stores all the learned heuristics for each optimization.
|
| 55 |
+
The key is the optimization name, and the value is a list of LearnedHeuristic objects.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
heuristics_initialized: bool = False
|
| 59 |
+
"""
|
| 60 |
+
A flag that indicates whether the learned heuristics have been initialized.
|
| 61 |
+
Set to true when the get_decision() function is called for the first time.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
metadata: AHMetadata,
|
| 67 |
+
context: AHContext,
|
| 68 |
+
) -> None:
|
| 69 |
+
self.metadata = metadata
|
| 70 |
+
self.context = context
|
| 71 |
+
|
| 72 |
+
def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
|
| 73 |
+
"""
|
| 74 |
+
Returns a list of learned heuristics for the given optimization name.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
if not LearnedHeuristicController.heuristics_initialized:
|
| 78 |
+
# learned heuristics are generated into the following package
|
| 79 |
+
learned_heuristics_package = "torch._inductor.autoheuristic.artifacts"
|
| 80 |
+
|
| 81 |
+
# learned heuristics have to be of type LearnedHeuristic
|
| 82 |
+
base_class = LearnedHeuristic
|
| 83 |
+
found_heuristics = find_and_instantiate_subclasses(
|
| 84 |
+
learned_heuristics_package, base_class
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
for learned_heuristic in found_heuristics:
|
| 88 |
+
opt_name = learned_heuristic.get_name()
|
| 89 |
+
LearnedHeuristicController.existing_heuristics[opt_name].append(
|
| 90 |
+
learned_heuristic
|
| 91 |
+
)
|
| 92 |
+
LearnedHeuristicController.heuristics_initialized = True
|
| 93 |
+
|
| 94 |
+
return LearnedHeuristicController.existing_heuristics[name]
|
| 95 |
+
|
| 96 |
+
def get_decision(self) -> Optional[Choice]:
|
| 97 |
+
"""
|
| 98 |
+
Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
|
| 99 |
+
which choice to make.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
heuristics = self.get_heuristics(self.metadata.name)
|
| 103 |
+
for heuristic in heuristics:
|
| 104 |
+
if heuristic.check_precondition(self.metadata, self.context):
|
| 105 |
+
return heuristic.get_decision(self.context, self.metadata.choices)
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
|
| 109 |
+
heuristics = self.get_heuristics(self.metadata.name)
|
| 110 |
+
for heuristic in heuristics:
|
| 111 |
+
if heuristic.check_precondition(self.metadata, self.context):
|
| 112 |
+
choices = heuristic.get_decisions_ranked(self.context)
|
| 113 |
+
if choices is None:
|
| 114 |
+
return None
|
| 115 |
+
avail_choices = [
|
| 116 |
+
choice for choice in choices if choice in self.metadata.choices
|
| 117 |
+
]
|
| 118 |
+
return avail_choices[:top_k]
|
| 119 |
+
return None
|
.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 5 |
+
AHContext,
|
| 6 |
+
AHMetadata,
|
| 7 |
+
Choice,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LearnedHeuristic:
|
| 12 |
+
"""
|
| 13 |
+
LearnedHeuristic is a base class for all learned heuristics.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self) -> None:
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def check_precondition(
|
| 20 |
+
self,
|
| 21 |
+
metadata: AHMetadata,
|
| 22 |
+
context: AHContext,
|
| 23 |
+
) -> bool:
|
| 24 |
+
return True
|
| 25 |
+
|
| 26 |
+
def get_decision(
|
| 27 |
+
self, context: AHContext, choices: list[Choice]
|
| 28 |
+
) -> Optional[Choice]:
|
| 29 |
+
return None
|
| 30 |
+
|
| 31 |
+
def get_confidence_threshold(self) -> float:
|
| 32 |
+
return 1.0
|
| 33 |
+
|
| 34 |
+
def get_name(self) -> str:
|
| 35 |
+
return ""
|
| 36 |
+
|
| 37 |
+
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LearnedHeuristicRegression(LearnedHeuristic):
|
| 42 |
+
def __init__(self) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
def get_feedback(self, context: AHContext, choice: Choice) -> float:
|
| 46 |
+
return 1.0
|
| 47 |
+
|
| 48 |
+
def get_decision(
|
| 49 |
+
self, context: AHContext, choices: list[Choice]
|
| 50 |
+
) -> Optional[Choice]:
|
| 51 |
+
choice2feedback = {}
|
| 52 |
+
for choice in choices:
|
| 53 |
+
predicted_feedback = self.get_feedback(context, choice)
|
| 54 |
+
choice2feedback[choice] = predicted_feedback
|
| 55 |
+
sorted_choices_feedback = sorted(
|
| 56 |
+
choice2feedback.items(), key=operator.itemgetter(1)
|
| 57 |
+
)
|
| 58 |
+
highest_feedback = sorted_choices_feedback[-1][1]
|
| 59 |
+
second_highest_feedback = sorted_choices_feedback[-2][1]
|
| 60 |
+
if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
|
| 61 |
+
return sorted_choices_feedback[-1][0]
|
| 62 |
+
# We are not sure which choice is the best one
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LearnedHeuristicDecision(LearnedHeuristic):
|
| 67 |
+
def __init__(self) -> None:
|
| 68 |
+
super().__init__()
|
| 69 |
+
|
| 70 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
def get_decision(
|
| 74 |
+
self, context: AHContext, choices: list[Choice]
|
| 75 |
+
) -> Optional[Choice]:
|
| 76 |
+
best_choices = self.get_best_choices(context)
|
| 77 |
+
if not best_choices:
|
| 78 |
+
return None
|
| 79 |
+
(best_choice_proba, best_choice_idx) = best_choices[0]
|
| 80 |
+
if best_choice_proba <= self.get_confidence_threshold():
|
| 81 |
+
return None
|
| 82 |
+
return self.get_choice(best_choice_idx)
|
| 83 |
+
|
| 84 |
+
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
|
| 85 |
+
feedback_idx_list = self.get_best_choices(context)
|
| 86 |
+
if feedback_idx_list is None:
|
| 87 |
+
return None
|
| 88 |
+
choices = [
|
| 89 |
+
self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
|
| 90 |
+
]
|
| 91 |
+
choices = [choice for choice in choices if choice is not None]
|
| 92 |
+
return choices
|
| 93 |
+
|
| 94 |
+
def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
|
| 95 |
+
return []
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
|
| 8 |
+
# "...
|
| 9 |
+
# from ..codecache import CudaKernelParamCache
|
| 10 |
+
# ..."
|
| 11 |
+
# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
|
| 15 |
+
if torch.version.hip is None and not force_hipify:
|
| 16 |
+
return source_codes
|
| 17 |
+
|
| 18 |
+
def c2_repl(m: re.Match[str]) -> object:
|
| 19 |
+
return PYTORCH_MAP[m.group(0)]
|
| 20 |
+
|
| 21 |
+
# We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
|
| 22 |
+
# it will apply positive lookbehind (?<=\W) to the pattern to avoid matching
|
| 23 |
+
# keyword at the beginning of code line. However, this can happen in codegen,
|
| 24 |
+
# which will cause the pattern to not match.
|
| 25 |
+
|
| 26 |
+
# Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example
|
| 27 |
+
# we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
|
| 28 |
+
RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
|
| 29 |
+
|
| 30 |
+
source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) # type: ignore[arg-type]
|
| 31 |
+
return source_codes
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Definition of AOTI runtime interface functions
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/inductor/aoti_runtime/interface.h>
|
| 4 |
+
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
|
| 5 |
+
|
| 6 |
+
#include <iostream>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
|
| 10 |
+
try { \
|
| 11 |
+
__VA_ARGS__ \
|
| 12 |
+
} catch (const std::exception& e) { \
|
| 13 |
+
std::cerr << "Error: " << e.what() << '\n'; \
|
| 14 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 15 |
+
} catch (...) { \
|
| 16 |
+
std::cerr << "Unknown exception occurred.\n"; \
|
| 17 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 18 |
+
} \
|
| 19 |
+
return AOTI_RUNTIME_SUCCESS;
|
| 20 |
+
|
| 21 |
+
#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
|
| 22 |
+
do { \
|
| 23 |
+
AOTI_RUNTIME_CHECK( \
|
| 24 |
+
actual_size == expected_size, \
|
| 25 |
+
"expected " + std::string(name) + " vector size to be " + \
|
| 26 |
+
std::to_string(expected_size) + ", but got " + \
|
| 27 |
+
std::to_string(actual_size)); \
|
| 28 |
+
} while (0)
|
| 29 |
+
|
| 30 |
+
// AOTInductor uses at::addmm_out, which doesn't supports
|
| 31 |
+
// arguments that requires gradient. For this reason, we
|
| 32 |
+
// enforce no_grad context for run APIs.
|
| 33 |
+
//
|
| 34 |
+
// A RAII, thread local (!) guard that enables or disables grad mode upon
|
| 35 |
+
// construction, and sets it back to the original value upon destruction.
|
| 36 |
+
struct AOTINoGradGuard {
|
| 37 |
+
AOTINoGradGuard() {
|
| 38 |
+
aoti_torch_grad_mode_set_enabled(false);
|
| 39 |
+
}
|
| 40 |
+
AOTINoGradGuard(const AOTINoGradGuard&) = delete;
|
| 41 |
+
AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete;
|
| 42 |
+
~AOTINoGradGuard() {
|
| 43 |
+
aoti_torch_grad_mode_set_enabled(prev_mode);
|
| 44 |
+
}
|
| 45 |
+
AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete;
|
| 46 |
+
AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete;
|
| 47 |
+
bool prev_mode{aoti_torch_grad_mode_is_enabled()};
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
extern "C" {
|
| 51 |
+
|
| 52 |
+
AOTIRuntimeError AOTInductorModelContainerCreate(
|
| 53 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 54 |
+
size_t num_models,
|
| 55 |
+
bool is_cpu,
|
| 56 |
+
const char* cubin_dir) {
|
| 57 |
+
return AOTInductorModelContainerCreateWithDevice(
|
| 58 |
+
container_handle,
|
| 59 |
+
num_models,
|
| 60 |
+
is_cpu ? "cpu" : "cuda",
|
| 61 |
+
cubin_dir);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
|
| 65 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 66 |
+
size_t num_models,
|
| 67 |
+
const char* device_str,
|
| 68 |
+
const char* cubin_dir) {
|
| 69 |
+
if (num_models == 0) {
|
| 70 |
+
std::cerr << "Error: num_models must be positive, but got 0\n";
|
| 71 |
+
return AOTI_RUNTIME_FAILURE;
|
| 72 |
+
}
|
| 73 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 74 |
+
std::optional<std::string> cubin_dir_opt;
|
| 75 |
+
if (cubin_dir != nullptr) {
|
| 76 |
+
cubin_dir_opt.emplace(cubin_dir);
|
| 77 |
+
}
|
| 78 |
+
auto* container = new torch::aot_inductor::AOTInductorModelContainer(
|
| 79 |
+
num_models, std::string(device_str), cubin_dir_opt);
|
| 80 |
+
*container_handle =
|
| 81 |
+
reinterpret_cast<AOTInductorModelContainerHandle>(container);
|
| 82 |
+
})
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
AOTIRuntimeError AOTInductorModelContainerDelete(
|
| 86 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 87 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 88 |
+
auto* container =
|
| 89 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 90 |
+
container_handle);
|
| 91 |
+
delete container;
|
| 92 |
+
});
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
AOTIRuntimeError AOTInductorModelContainerRun(
|
| 96 |
+
AOTInductorModelContainerHandle container_handle,
|
| 97 |
+
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
|
| 98 |
+
// are stolen; the array itself is borrowed
|
| 99 |
+
size_t num_inputs,
|
| 100 |
+
AtenTensorHandle*
|
| 101 |
+
output_handles, // array for writing output AtenTensorHandle; handles
|
| 102 |
+
// will be stolen by the caller; the array itself is
|
| 103 |
+
// borrowed
|
| 104 |
+
size_t num_outputs,
|
| 105 |
+
AOTInductorStreamHandle stream_handle,
|
| 106 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 107 |
+
auto* container =
|
| 108 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 109 |
+
container_handle);
|
| 110 |
+
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
|
| 111 |
+
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
|
| 112 |
+
|
| 113 |
+
auto stream =
|
| 114 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 115 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 116 |
+
AOTINoGradGuard guard;
|
| 117 |
+
container->run(
|
| 118 |
+
input_handles, output_handles, stream, proxy_executor_handle);
|
| 119 |
+
})
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
|
| 123 |
+
AOTInductorModelContainerHandle container_handle,
|
| 124 |
+
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
|
| 125 |
+
// are stolen; the array itself is borrowed
|
| 126 |
+
size_t num_inputs,
|
| 127 |
+
AtenTensorHandle*
|
| 128 |
+
output_handles, // array for writing output AtenTensorHandle; handles
|
| 129 |
+
// will be stolen by the caller; the array itself is
|
| 130 |
+
// borrowed
|
| 131 |
+
size_t num_outputs,
|
| 132 |
+
AOTInductorStreamHandle stream_handle,
|
| 133 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 134 |
+
auto* container =
|
| 135 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 136 |
+
container_handle);
|
| 137 |
+
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
|
| 138 |
+
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
|
| 139 |
+
|
| 140 |
+
auto stream =
|
| 141 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 142 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 143 |
+
AOTINoGradGuard guard;
|
| 144 |
+
container->run_single_threaded(
|
| 145 |
+
input_handles, output_handles, stream, proxy_executor_handle);
|
| 146 |
+
})
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
|
| 150 |
+
AOTInductorModelContainerHandle container_handle,
|
| 151 |
+
size_t* num_constants) {
|
| 152 |
+
auto* container =
|
| 153 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 154 |
+
container_handle);
|
| 155 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 156 |
+
{ *num_constants = container->num_constants(); })
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantName(
|
| 160 |
+
AOTInductorModelContainerHandle container_handle,
|
| 161 |
+
size_t idx,
|
| 162 |
+
const char** name) {
|
| 163 |
+
auto* container =
|
| 164 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 165 |
+
container_handle);
|
| 166 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 167 |
+
{ *name = container->constant_name(idx); })
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
|
| 171 |
+
AOTInductorModelContainerHandle container_handle,
|
| 172 |
+
size_t idx,
|
| 173 |
+
const char** original_fqn) {
|
| 174 |
+
auto* container =
|
| 175 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 176 |
+
container_handle);
|
| 177 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 178 |
+
{ *original_fqn = container->constant_original_fqn(idx); })
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
|
| 182 |
+
AOTInductorModelContainerHandle container_handle,
|
| 183 |
+
size_t idx,
|
| 184 |
+
bool* from_folded) {
|
| 185 |
+
auto* container =
|
| 186 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
|
| 187 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantType(
|
| 191 |
+
AOTInductorModelContainerHandle container_handle,
|
| 192 |
+
size_t idx,
|
| 193 |
+
int32_t* type) {
|
| 194 |
+
auto* container =
|
| 195 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
|
| 196 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); })
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
|
| 200 |
+
AOTInductorModelContainerHandle container_handle,
|
| 201 |
+
size_t idx,
|
| 202 |
+
int32_t* dtype) {
|
| 203 |
+
auto* container =
|
| 204 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 205 |
+
container_handle);
|
| 206 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 207 |
+
{ *dtype = container->constant_dtype(idx); })
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize(
|
| 211 |
+
AOTInductorModelContainerHandle container_handle,
|
| 212 |
+
size_t idx,
|
| 213 |
+
size_t* data_size) {
|
| 214 |
+
auto* container =
|
| 215 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 216 |
+
container_handle);
|
| 217 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 218 |
+
{ *data_size = container->constant_data_size(idx); })
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap(
|
| 222 |
+
AOTInductorModelContainerHandle container_handle,
|
| 223 |
+
AOTInductorConstantMapHandle constant_map_handle,
|
| 224 |
+
bool use_inactive) {
|
| 225 |
+
auto* container =
|
| 226 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 227 |
+
container_handle);
|
| 228 |
+
auto constants_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 229 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 230 |
+
{ const auto ret = container->extract_constants_map(use_inactive);
|
| 231 |
+
for (const auto& pair: ret) {
|
| 232 |
+
constants_map->emplace(pair.first, pair.second);
|
| 233 |
+
}
|
| 234 |
+
})
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer(
|
| 238 |
+
AOTInductorModelContainerHandle container_handle,
|
| 239 |
+
AOTInductorConstantMapHandle constant_map_handle,
|
| 240 |
+
bool use_inactive,
|
| 241 |
+
bool validate_full_update) {
|
| 242 |
+
auto* container =
|
| 243 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 244 |
+
container_handle);
|
| 245 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 246 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 247 |
+
container->update_constant_buffer(
|
| 248 |
+
*input_map, use_inactive, validate_full_update, /* user_managed = */ true);
|
| 249 |
+
})
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
|
| 253 |
+
AOTInductorModelContainerHandle container_handle,
|
| 254 |
+
AOTInductorConstantMapHandle constant_map_handle,
|
| 255 |
+
bool use_inactive,
|
| 256 |
+
bool validate_full_update) {
|
| 257 |
+
auto* container =
|
| 258 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 259 |
+
container_handle);
|
| 260 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 261 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 262 |
+
container->update_constant_buffer(
|
| 263 |
+
*input_map, use_inactive, validate_full_update);
|
| 264 |
+
})
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
|
| 268 |
+
AOTInductorModelContainerHandle container_handle,
|
| 269 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 270 |
+
return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
|
| 271 |
+
constant_map_handle,
|
| 272 |
+
/*use_inactive*/ true,
|
| 273 |
+
/*validate_full_update*/ true);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer(
|
| 277 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 278 |
+
auto* container =
|
| 279 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 280 |
+
container_handle);
|
| 281 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 282 |
+
container->free_inactive_constant_buffer();
|
| 283 |
+
})
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
|
| 287 |
+
AOTInductorModelContainerHandle container_handle,
|
| 288 |
+
bool use_inactive,
|
| 289 |
+
AOTInductorStreamHandle stream_handle,
|
| 290 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 291 |
+
auto* container =
|
| 292 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 293 |
+
container_handle);
|
| 294 |
+
auto stream =
|
| 295 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 296 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 297 |
+
AOTINoGradGuard guard;
|
| 298 |
+
container->run_const_fold(use_inactive, stream, proxy_executor_handle);
|
| 299 |
+
})
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
|
| 303 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 304 |
+
auto* container =
|
| 305 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 306 |
+
container_handle);
|
| 307 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 308 |
+
container->swap_constant_buffer();
|
| 309 |
+
})
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
|
| 313 |
+
AOTInductorModelContainerHandle container_handle,
|
| 314 |
+
size_t* ret_num_inputs) {
|
| 315 |
+
auto* container =
|
| 316 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 317 |
+
container_handle);
|
| 318 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 319 |
+
{ *ret_num_inputs = container->num_inputs(); })
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
AOTIRuntimeError AOTInductorModelContainerGetInputName(
|
| 323 |
+
AOTInductorModelContainerHandle container_handle,
|
| 324 |
+
size_t input_idx,
|
| 325 |
+
const char** ret_input_names) {
|
| 326 |
+
auto* container =
|
| 327 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 328 |
+
container_handle);
|
| 329 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 330 |
+
{ *ret_input_names = container->input_name(input_idx); })
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
|
| 334 |
+
AOTInductorModelContainerHandle container_handle,
|
| 335 |
+
size_t* ret_num_outputs) {
|
| 336 |
+
auto* container =
|
| 337 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 338 |
+
container_handle);
|
| 339 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 340 |
+
{ *ret_num_outputs = container->num_outputs(); })
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
|
| 344 |
+
AOTInductorModelContainerHandle container_handle,
|
| 345 |
+
size_t output_idx,
|
| 346 |
+
const char** ret_output_names) {
|
| 347 |
+
auto* container =
|
| 348 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 349 |
+
container_handle);
|
| 350 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 351 |
+
{ *ret_output_names = container->output_name(output_idx); })
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
|
| 355 |
+
AOTInductorModelContainerHandle container_handle,
|
| 356 |
+
const char** in_spec,
|
| 357 |
+
const char** out_spec) {
|
| 358 |
+
auto* container =
|
| 359 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 360 |
+
container_handle);
|
| 361 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 362 |
+
*in_spec = container->get_in_spec();
|
| 363 |
+
*out_spec = container->get_out_spec();
|
| 364 |
+
})
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
AOTIRuntimeError AOTInductorModelCreate(
|
| 368 |
+
AOTInductorModelHandle* model_handle,
|
| 369 |
+
AOTInductorConstantMapHandle constant_map_handle){
|
| 370 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 371 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 372 |
+
auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
|
| 373 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 374 |
+
|
| 375 |
+
auto model = new torch::aot_inductor::AOTInductorModel(
|
| 376 |
+
constant_map,
|
| 377 |
+
constant_array,
|
| 378 |
+
"cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
|
| 379 |
+
""
|
| 380 |
+
);
|
| 381 |
+
|
| 382 |
+
if (input_map) {
|
| 383 |
+
for (auto const& kv : *input_map) {
|
| 384 |
+
constant_map->emplace(kv.first, kv.second);
|
| 385 |
+
}
|
| 386 |
+
} else {
|
| 387 |
+
model->load_constants();
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
*model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
|
| 391 |
+
})}
|
| 392 |
+
|
| 393 |
+
AOTIRuntimeError AOTInductorModelRun(
|
| 394 |
+
AOTInductorModelHandle model_handle,
|
| 395 |
+
AtenTensorHandle* input_handles,
|
| 396 |
+
AtenTensorHandle* output_handles) {
|
| 397 |
+
auto model =
|
| 398 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 399 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 400 |
+
AOTINoGradGuard guard;
|
| 401 |
+
model->run_impl(
|
| 402 |
+
input_handles,
|
| 403 |
+
output_handles,
|
| 404 |
+
(torch::aot_inductor::DeviceStreamType) nullptr,
|
| 405 |
+
nullptr);
|
| 406 |
+
})
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
|
| 410 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 411 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
|
| 412 |
+
model_handle);
|
| 413 |
+
delete model;
|
| 414 |
+
})}
|
| 415 |
+
|
| 416 |
+
AOTIRuntimeError AOTInductorModelGetNumOutputs(
|
| 417 |
+
AOTInductorModelHandle model_handle,
|
| 418 |
+
size_t* ret_num_outputs) {
|
| 419 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 420 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 421 |
+
*ret_num_outputs = model->num_outputs();
|
| 422 |
+
})
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
|
| 426 |
+
AOTInductorModelHandle model_handle,
|
| 427 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 428 |
+
auto model =
|
| 429 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 430 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 431 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 432 |
+
auto input_map =
|
| 433 |
+
reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
|
| 434 |
+
constant_map_handle);
|
| 435 |
+
|
| 436 |
+
for (auto const& kv : *input_map) {
|
| 437 |
+
constant_map->emplace(kv.first, kv.second);
|
| 438 |
+
}
|
| 439 |
+
model->update_constants_map(std::move(constant_map));
|
| 440 |
+
})
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
} // extern "C"
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import functools
|
| 3 |
+
import textwrap
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import sympy
|
| 7 |
+
from sympy import Expr, Symbol
|
| 8 |
+
|
| 9 |
+
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
| 10 |
+
|
| 11 |
+
from ..utils import sympy_dot, sympy_subs
|
| 12 |
+
from ..virtualized import V
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BlockPatternMatcher:
|
| 16 |
+
"""
|
| 17 |
+
Matches block indexing expressions.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@classmethod
|
| 21 |
+
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
|
| 22 |
+
"""
|
| 23 |
+
Given a sympy expression, return the subexpression comprised only of terms
|
| 24 |
+
involving the specified symbol.
|
| 25 |
+
|
| 26 |
+
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
|
| 27 |
+
this returns `x * 5 + x ** 2`.
|
| 28 |
+
"""
|
| 29 |
+
expr = cls._preprocess(expr)
|
| 30 |
+
return sympy.S.Zero + sum(
|
| 31 |
+
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def get_slice_numels(dims: list[Expr]) -> list[Expr]:
|
| 36 |
+
"""
|
| 37 |
+
Compute the cumulative size of each dimension's slice.
|
| 38 |
+
This proceeds from the last dim up to the second.
|
| 39 |
+
"""
|
| 40 |
+
numels = collections.deque([sympy.S.One])
|
| 41 |
+
for dim in dims[:0:-1]:
|
| 42 |
+
numel = dim * numels[0]
|
| 43 |
+
numels.appendleft(numel)
|
| 44 |
+
return [*numels]
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def _preprocess(expr: Expr) -> Expr:
|
| 48 |
+
# Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
|
| 49 |
+
return expr.expand(identity=True)
|
| 50 |
+
|
| 51 |
+
@classmethod
|
| 52 |
+
def match_mod_div_block_expr(
|
| 53 |
+
cls,
|
| 54 |
+
index: Expr,
|
| 55 |
+
index_var: Symbol,
|
| 56 |
+
numel: Expr,
|
| 57 |
+
num_dims: int,
|
| 58 |
+
) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]:
|
| 59 |
+
"""
|
| 60 |
+
Matches modular indexing expressions, converting them to implied block dimensions and strides.
|
| 61 |
+
See triton.py for more information.
|
| 62 |
+
"""
|
| 63 |
+
index = cls._preprocess(index)
|
| 64 |
+
|
| 65 |
+
# Pattern match to find the strides and offset.
|
| 66 |
+
wild = functools.partial(sympy.Wild, exclude=[index_var])
|
| 67 |
+
dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
|
| 68 |
+
strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]
|
| 69 |
+
|
| 70 |
+
# The first dimension's index is computed by division.
|
| 71 |
+
# The remaining are computed by modulo.
|
| 72 |
+
slice_numels = cls.get_slice_numels(dims[:num_dims])
|
| 73 |
+
block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
|
| 74 |
+
ModularIndexing(index_var, numel, dim)
|
| 75 |
+
for dim, numel in zip(dims[1:], slice_numels[1:])
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
# Calculate a linear index from block indices.
|
| 79 |
+
match_expr = sympy_dot(strides, block_index_exprs)
|
| 80 |
+
|
| 81 |
+
# Heuristic: if the number of dimensions is high, check that the minimum requirements
|
| 82 |
+
# are met before attempting an expensive full match. see triton.py:match_mod_div_block
|
| 83 |
+
# for more details. In short, here we check that each subexpression in sympy.Add contains
|
| 84 |
+
# only FloorDiv or ModularIndexing expressions.
|
| 85 |
+
if num_dims >= 5:
|
| 86 |
+
stride, denom, other = sympy.symbols("stride denominator other", cls=wild)
|
| 87 |
+
mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
|
| 88 |
+
floor_div_pattern = stride * FloorDiv(index_var, denom)
|
| 89 |
+
first_dim_floor_div_matched = False
|
| 90 |
+
match_failed = False
|
| 91 |
+
for arg in sympy.Add.make_args(index):
|
| 92 |
+
if arg.match(floor_div_pattern):
|
| 93 |
+
# There should only be a single FloorDiv(index, denom) expression
|
| 94 |
+
# corresponding to the first dimension
|
| 95 |
+
if first_dim_floor_div_matched:
|
| 96 |
+
match_failed = True
|
| 97 |
+
break
|
| 98 |
+
first_dim_floor_div_matched = True
|
| 99 |
+
elif arg.match(mod_div_pattern):
|
| 100 |
+
continue
|
| 101 |
+
else:
|
| 102 |
+
match_failed = True
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
if match_failed:
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
# Pattern match.
|
| 109 |
+
match = index.match(match_expr)
|
| 110 |
+
if match is None:
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
# Provide default values for unmatched dims and strides.
|
| 114 |
+
for dim in dims[1:]:
|
| 115 |
+
if dim not in match:
|
| 116 |
+
match[dim] = sympy.S.One
|
| 117 |
+
for stride in strides[1:]:
|
| 118 |
+
if stride not in match:
|
| 119 |
+
match[stride] = sympy.S.Zero
|
| 120 |
+
|
| 121 |
+
sizevars = V.graph.sizevars
|
| 122 |
+
|
| 123 |
+
def get_match(expr: Expr) -> Expr:
|
| 124 |
+
return sizevars.lookup_precomputed_size(match[expr])
|
| 125 |
+
|
| 126 |
+
# Replace wildcards with matched expressions.
|
| 127 |
+
dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
|
| 128 |
+
strides = [get_match(stride) for stride in strides]
|
| 129 |
+
slice_numels = cls.get_slice_numels(dims)
|
| 130 |
+
block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs]
|
| 131 |
+
|
| 132 |
+
# The leading dimension is not directly matched in our expression.
|
| 133 |
+
# We solve for it by dividing the range tree numel by the product of
|
| 134 |
+
# all other dimensions. We quit if they are not known to be divisible.
|
| 135 |
+
assert dims[0] not in match, "Expected not to match the leading dimension!"
|
| 136 |
+
if not sizevars.statically_known_multiple_of(numel, slice_numels[0]):
|
| 137 |
+
return None
|
| 138 |
+
dims[0] = numel / slice_numels[0]
|
| 139 |
+
|
| 140 |
+
# Sanity check that we can recover the index from the matched subexpressions.
|
| 141 |
+
matched_index = sympy_dot(strides, block_index_exprs)
|
| 142 |
+
assert sizevars.statically_known_equals(
|
| 143 |
+
# New precomputed replacements may be generated when the `get_match` function
|
| 144 |
+
# above is called, but the `index` that is being matched has not been updated.
|
| 145 |
+
# So remove them when checking for equivalence e.g. if ps0=3*s0 and
|
| 146 |
+
# index=3*s0*expr, matched_index=ps0*expr, then index == matched_index
|
| 147 |
+
sizevars.remove_precomputed_replacements(matched_index),
|
| 148 |
+
sizevars.remove_precomputed_replacements(index),
|
| 149 |
+
), textwrap.dedent(
|
| 150 |
+
f"""
|
| 151 |
+
Invalid match!
|
| 152 |
+
Index: {index}
|
| 153 |
+
Matched expression: {matched_index}
|
| 154 |
+
"""
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
return dims, strides, block_index_exprs
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def match_affine_block_expr(
|
| 161 |
+
cls,
|
| 162 |
+
index: Expr,
|
| 163 |
+
index_var: Symbol,
|
| 164 |
+
) -> Optional[Expr]:
|
| 165 |
+
"""
|
| 166 |
+
Matches simple expressions of the form stride * index, returning the
|
| 167 |
+
stride.
|
| 168 |
+
"""
|
| 169 |
+
index = cls._preprocess(index)
|
| 170 |
+
stride = sympy.Wild("stride", exclude=[index_var])
|
| 171 |
+
m = index.match(index_var * stride)
|
| 172 |
+
if m is None:
|
| 173 |
+
return None
|
| 174 |
+
|
| 175 |
+
return m[stride]
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py
ADDED
|
@@ -0,0 +1,2691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import atexit
|
| 4 |
+
import contextlib
|
| 5 |
+
import dataclasses
|
| 6 |
+
import enum
|
| 7 |
+
import functools
|
| 8 |
+
import itertools
|
| 9 |
+
import logging
|
| 10 |
+
import math
|
| 11 |
+
import operator
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import tempfile
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from enum import auto, Enum
|
| 17 |
+
from itertools import chain
|
| 18 |
+
from typing import (
|
| 19 |
+
Any,
|
| 20 |
+
Callable,
|
| 21 |
+
cast,
|
| 22 |
+
ClassVar,
|
| 23 |
+
Generic,
|
| 24 |
+
NamedTuple,
|
| 25 |
+
Optional,
|
| 26 |
+
TYPE_CHECKING,
|
| 27 |
+
Union,
|
| 28 |
+
)
|
| 29 |
+
from typing_extensions import Self, TypeVar
|
| 30 |
+
|
| 31 |
+
import sympy
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.fx
|
| 35 |
+
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
| 36 |
+
from torch.utils import _pytree as pytree
|
| 37 |
+
from torch.utils._ordered_set import OrderedSet
|
| 38 |
+
from torch.utils._sympy.numbers import int_oo
|
| 39 |
+
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
|
| 40 |
+
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
| 41 |
+
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
| 42 |
+
|
| 43 |
+
from .. import config, metrics
|
| 44 |
+
from ..dtype_propagation import DtypePropagationOpsHandler
|
| 45 |
+
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
|
| 46 |
+
from ..utils import (
|
| 47 |
+
boolean_ops,
|
| 48 |
+
DeferredLineBase,
|
| 49 |
+
generate_assert,
|
| 50 |
+
get_current_backend,
|
| 51 |
+
IndentedBuffer,
|
| 52 |
+
ir_dataclass,
|
| 53 |
+
ScopedDict,
|
| 54 |
+
sympy_dot,
|
| 55 |
+
sympy_index_symbol,
|
| 56 |
+
sympy_subs,
|
| 57 |
+
triton_type,
|
| 58 |
+
unique,
|
| 59 |
+
)
|
| 60 |
+
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if TYPE_CHECKING:
|
| 64 |
+
from collections.abc import Iterator, MutableMapping, Sequence
|
| 65 |
+
|
| 66 |
+
from torch.fx import GraphModule
|
| 67 |
+
|
| 68 |
+
from ..custom_graph_pass import CustomGraphModulePass
|
| 69 |
+
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
|
| 70 |
+
from ..loop_body import LoopBody
|
| 71 |
+
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
|
| 72 |
+
from .wrapper import PythonWrapperCodegen
|
| 73 |
+
|
| 74 |
+
_T = TypeVar("_T")
|
| 75 |
+
SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling]
|
| 76 |
+
WrapperConstructor = type[PythonWrapperCodegen]
|
| 77 |
+
SymbolLike = Union[str, sympy.Symbol]
|
| 78 |
+
|
| 79 |
+
# OpVarT should really be Union[CSEVariable, str], however this
|
| 80 |
+
# causes typing errors in subclasses (defined in other files).
|
| 81 |
+
OpVarT = str
|
| 82 |
+
|
| 83 |
+
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
| 84 |
+
log = logging.getLogger(__name__)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def data_type_logger(msg: str) -> None:
    """Emit a data-type-propagation message on the schedule artifact logger.

    Returns early unless DEBUG is enabled, so callers can log freely without
    paying formatting costs in the common (non-debug) case.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclasses.dataclass
class FileBackedGraphModule:
    """
    Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
    map back to a GraphModule instead of Python source.
    """

    # The FX graph module this object wraps.
    gm: GraphModule
    # Callable invoked by call(); takes the runtime args unpacked.
    compiled_fn: Callable[..., Any]

    def __post_init__(self) -> None:
        # Write the code to a file for compatibility with debugging utilities.
        # The file is deleted upon program termination.
        # delete=False so the file survives the `with` block below; cleanup is
        # instead handled by the atexit hook registered right after creation.
        self.tempfile = tempfile.NamedTemporaryFile(
            mode="w+", suffix=".py", delete=False
        )
        atexit.register(os.remove, self.tempfile.name)
        # The `with` block closes (but does not delete) the file once the
        # generated source has been written out.
        with self.tempfile as f:
            f.write(self.value)

    @property
    def __file__(self) -> str:
        # Mimics a real module's __file__ attribute for debugging tools.
        return self.tempfile.name

    def call(self, args: list[Any]) -> Any:
        """Invoke the compiled function, unpacking the argument list."""
        return self.compiled_fn(*args)

    @property
    def value(self) -> str:
        """The generated Python source of the wrapped GraphModule."""
        return self.gm.code
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class WorkspaceZeroMode(enum.Enum):
    """How (and whether) a kernel workspace buffer is zero-initialized."""

    UNINITIALIZED = 0
    ZERO_ON_CALL = 1  # kernel may leave workspace dirty
    ZERO_PER_GRAPH = 2  # must be re-zeroed by kernel

    @staticmethod
    def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode:
        """Merge two zeroing requirements; UNINITIALIZED yields to the other side."""
        if b in (a, WorkspaceZeroMode.UNINITIALIZED):
            return a
        if a is WorkspaceZeroMode.UNINITIALIZED:
            return b
        # Conflicting non-trivial modes have no defined merge.
        raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})")

    @staticmethod
    def from_bool(zero_fill: bool) -> WorkspaceZeroMode:
        """Translate a legacy zero_fill flag into the enum."""
        return (
            WorkspaceZeroMode.ZERO_ON_CALL
            if zero_fill
            else WorkspaceZeroMode.UNINITIALIZED
        )
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class CodegenSymbol(ABC):
    """An IR object possibly corresponding to a variable in the wrapper code."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the variable name this symbol goes by in the wrapper code."""

    @abstractmethod
    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        """Return an example value (a tensor or a sympy symbol) for this symbol."""
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@ir_dataclass(frozen=True)
class WorkspaceArg(CodegenSymbol):
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.

    Args:
        count: Number of ``dtype`` elements in the buffer (symbolic).
        zero_mode: How/whether the buffer must be zero-initialized.
        device: Device the workspace is allocated on.
        outer_name: Variable name used in the wrapper code.
        inner_name: Argument name seen inside the kernel.
        dtype: Element type of the buffer.
    """

    count: sympy.Expr
    zero_mode: WorkspaceZeroMode
    device: torch.device
    outer_name: str
    inner_name: str = "ws_ptr"
    dtype: torch.dtype = torch.uint8

    @staticmethod
    def unique_name(prefix: str = "workspace_") -> str:
        # Draws from the per-graph counter, so names are unique within a graph.
        return f"{prefix}{next(V.graph.workspace_id)}"

    @staticmethod
    def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool:
        """Two workspaces can share one allocation only if they agree on
        kernel-side name, element type, and device."""
        return (
            a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device
        )

    @staticmethod
    def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Concatenate two compatible workspaces into a single larger one."""
        return WorkspaceArg(
            count=a.count + b.count,
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    @staticmethod
    def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Overlap two workspaces: one allocation sized to the larger count."""
        assert (
            a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name
        )
        return WorkspaceArg(
            count=sympy.Max(a.count, b.count),
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code
    def get_device(self) -> torch.device:
        return self.device

    get_device_or_error = get_device

    def get_dtype(self) -> torch.dtype:
        return self.dtype

    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        return self.get_layout().get_example()

    def get_layout(self) -> FixedLayout:
        from ..ir import FixedLayout

        # Workspaces are always 1-D and contiguous (stride 1).
        return FixedLayout(
            device=self.device,
            dtype=self.dtype,
            size=[self.count],
            stride=[1],
        )

    @property
    def layout(self) -> FixedLayout:
        return self.get_layout()

    # Buffer-protocol aliases used by allocation code paths.
    get_output_spec = get_layout
    maybe_get_output_spec = get_layout
    maybe_get_layout = get_layout

    def get_offset(self) -> sympy.Expr:
        return sympy.S.Zero

    def get_size(self) -> list[sympy.Expr]:
        return [self.count]

    def get_stride(self) -> list[sympy.Expr]:
        return [sympy.S.One]

    def get_name(self) -> str:
        return self.outer_name

    def get_inputs_that_alias_output(self) -> list[str]:
        # A workspace never aliases any kernel input.
        return []
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class TritonScratchWorkspace:
    """A scratch-buffer size paired with a lazily evaluated dtype-string factory."""

    def __init__(self, size: int, generate_dtype_str: Callable[..., str]):
        # Size of the scratch buffer.
        self.size = size
        # Deferred so the dtype string is only produced when actually needed.
        self._generate_dtype_str = generate_dtype_str

    def generate_dtype_str(self) -> str:
        """Produce the dtype string by invoking the stored callback."""
        return self._generate_dtype_str()
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@dataclasses.dataclass
class TensorArg:
    """A tensor-valued kernel argument."""

    # Kernel-side argument name.
    name: str
    # Name of the buffer this argument refers to.
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.S.Zero  # c++ only
    alias_of: Optional[str] = None  # halide only
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
@dataclasses.dataclass
class SizeArg:
    """A scalar (symbolic size) kernel argument."""

    name: str
    expr: sympy.Expr

    @property
    def alias_of(self) -> Optional[str]:
        # Sizes never alias anything; property mirrors TensorArg.alias_of so
        # callers can treat the two argument kinds uniformly.
        return None
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@dataclasses.dataclass
class ConstexprArg:
    """A compile-time-constant kernel argument; only its name is tracked."""

    name: str
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@dataclasses.dataclass
class TMADescriptorArg:
    """A kernel argument carrying a Triton TMA descriptor."""

    name: str
    api_type: str  # "experimental" or "stable"
    block_shape: Optional[list[sympy.Expr]]  # only needed for "stable"
    dtype: Optional[torch.dtype]  # only needed for "stable"
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@dataclasses.dataclass
class DeviceCodegen:
    """Bundle of codegen entry points registered for a single device type."""

    # Builds the kernel-scheduling backend for this device.
    scheduling: SchedulingConstructor
    # Python wrapper codegen class.
    wrapper_codegen: WrapperConstructor
    # Optional C++ wrapper codegen class (used when cpp_wrapper is requested).
    cpp_wrapper_codegen: Optional[WrapperConstructor] = None
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Union of every argument kind a generated kernel can accept.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]

# Registry mapping device type (e.g. "cpu", "cuda") to its codegen classes;
# populated via register_backend_for_device().
device_codegens: dict[str, DeviceCodegen] = {}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class DeviceOpOverrides:
    """Per-device hooks for emitting device-management code in generated wrappers.

    Each method returns a code snippet (as a string) specific to the device;
    subclasses registered via register_device_op_overrides() implement them.
    Every base implementation raises NotImplementedError.
    """

    # --- Python wrapper snippets ---
    def import_get_raw_stream_as(self, name: str) -> str:
        raise NotImplementedError

    def set_device(self, device_idx: int) -> str:
        raise NotImplementedError

    def synchronize(self) -> str:
        raise NotImplementedError

    def device_guard(self, device_idx: int) -> str:
        raise NotImplementedError

    # --- C++ / AOTInductor wrapper snippets ---
    def cpp_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_getStreamFromExternal(self) -> str:
        raise NotImplementedError

    def kernel_header(self) -> str:
        raise NotImplementedError

    def kernel_driver(self) -> str:
        raise NotImplementedError

    def cpp_stream_type(self) -> str:
        raise NotImplementedError

    def aoti_get_stream(self) -> str:
        raise NotImplementedError

    def cpp_kernel_type(self) -> str:
        raise NotImplementedError

    def cpp_device_ptr(self) -> str:
        raise NotImplementedError

    def tma_descriptor_helpers(self) -> str:
        raise NotImplementedError

    def cpp_global_scratch(
        self, idx: int, workspace: TritonScratchWorkspace
    ) -> Optional[tuple[list[str], str]]:
        # optionally return (scratch definition, arg name)
        raise NotImplementedError
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Registry of per-device DeviceOpOverrides; see register_device_op_overrides().
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
# Optional custom FX graph pass per device, filled in by register_backend_for_device().
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
|
| 373 |
+
# For any new backend looking to integrate with Inductor, customization of these two main
|
| 374 |
+
# parts are necessary to generate its specific code.
|
| 375 |
+
#
|
| 376 |
+
# Kernel code generation is determined by different Scheduling. Consequently, a new
|
| 377 |
+
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
|
| 378 |
+
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
|
| 379 |
+
#
|
| 380 |
+
# For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code
|
| 381 |
+
# that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen,
|
| 382 |
+
# and override specific member functions to create backend-specific Python wrapper code.
|
| 383 |
+
#
|
| 384 |
+
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
|
| 385 |
+
# of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces
|
| 386 |
+
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
|
| 387 |
+
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
|
| 388 |
+
# register_backend_for_device, to equip a new backend at runtime.
|
| 389 |
+
#
|
| 390 |
+
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
|
| 391 |
+
# This backend can be used as a reference:
|
| 392 |
+
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
|
| 393 |
+
def register_backend_for_device(
    device: str,
    device_scheduling: SchedulingConstructor,
    device_wrapper_codegen: WrapperConstructor,
    device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
    device_custom_pass: Optional[CustomGraphModulePass] = None,
) -> None:
    """Register codegen classes (and an optional custom FX pass) for ``device``.

    Overwrites any previous registration for the same device type.
    """
    entry = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
        cpp_wrapper_codegen=device_cpp_wrapper_codegen,
    )
    device_codegens[device] = entry
    custom_backend_passes[device] = device_custom_pass
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class BackendFeature(Enum):
    """Optional capabilities a codegen backend may declare.

    Queried through get_backend_features() / has_backend_feature().
    """

    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def get_backend_features(
    device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
    """Return the feature set supported by the backend registered for ``device``.

    ``None`` yields an empty set. Accepts either a torch.device or a device
    type string; backend registration is ensured before the lookup.
    """
    if device is None:
        return OrderedSet()
    init_backend_registration()

    # Normalize to (device_type string, torch.device) without changing which
    # string is used for the registry lookup.
    if isinstance(device, str):
        device_type = device
        device = torch.device(device_type)
    else:
        assert isinstance(device, torch.device), type(device)
        device_type = device.type

    scheduling_ctor = get_scheduling_for_device(device_type)
    assert scheduling_ctor
    return scheduling_ctor(None).get_backend_features(device)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def has_backend_feature(
    device: Union[torch.device, str, None], feature: BackendFeature
) -> bool:
    """See also V.graph.has_feature"""
    assert isinstance(feature, BackendFeature)
    supported = get_backend_features(device)
    return feature in supported
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
    """Return the scheduling constructor registered for ``device``, or None."""
    entry = device_codegens.get(device)
    return entry.scheduling if entry is not None else None
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def get_wrapper_codegen_for_device(
    device: str, cpp_wrapper: bool = False
) -> Optional[WrapperConstructor]:
    """Return the wrapper codegen class registered for ``device``.

    Picks the C++ variant when ``cpp_wrapper`` is set; None when the device
    has no registration.
    """
    entry = device_codegens.get(device)
    if entry is None:
        return None
    return entry.cpp_wrapper_codegen if cpp_wrapper else entry.wrapper_codegen
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
    """Return the custom FX pass registered for ``device``, or None."""
    # Registered values are already Optional, so .get's None default is exact.
    return custom_backend_passes.get(device)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
@functools.cache
def init_backend_registration() -> None:
    """Register the in-tree codegen backends (cpu/cuda/xpu/mps and, if present,
    the privateuse1 backend). Idempotent via functools.cache, and each branch
    also skips devices already registered by someone else.
    """
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
    from .cpp_wrapper_gpu import CppWrapperGpu
    from .cpp_wrapper_mps import CppWrapperMps
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .mps import MetalScheduling
    from .triton import TritonScheduling
    from .wrapper import PythonWrapperCodegen

    if get_scheduling_for_device("cpu") is None:
        # Scheduling is chosen lazily via config.cpu_backend at call time.
        cpu_backends = {
            "cpp": CppScheduling,
            "halide": HalideScheduling,
            "triton": TritonScheduling,
        }
        register_backend_for_device(
            "cpu",
            lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
            PythonWrapperCodegen,
            CppWrapperCpuArrayRef
            if config.aot_inductor.allow_stack_allocation
            else CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {
            "triton": CUDACombinedScheduling,
            "halide": HalideScheduling,
        }
        register_backend_for_device(
            "cuda",
            lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("xpu") is None:
        register_backend_for_device(
            "xpu",
            TritonScheduling,
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("mps") is None:
        register_backend_for_device(
            "mps",
            MetalScheduling,
            PythonWrapperCodegen,
            CppWrapperMps,
        )

    # Out-of-tree (privateuse1) backends supply their codegen classes through
    # torch's custom-module hooks; a renamed backend indicates one is present.
    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            # Backend module does not expose the hooks; best-effort, skip.
            pass
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def index_prevent_reordering(
    index: Sequence[sympy.Expr],
    index_vars: Sequence[sympy.Expr],
    sizes: Sequence[sympy.Expr],
) -> list[sympy.Expr]:
    """Return *index* extended with a contiguous indexing expression.

    Appending ``dot(index_vars, contiguous_strides(sizes))`` pins the
    iteration order: downstream passes that would otherwise reorder the
    dimensions see the extra contiguous index and leave them alone.
    """
    from ..ir import FlexibleLayout

    contiguous_index = sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))
    return list(index) + [contiguous_index]
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def register_device_op_overrides(
    device: str, device_op_overrides: DeviceOpOverrides
) -> None:
    """Register the ``DeviceOpOverrides`` table used when generating code for *device*."""
    device_op_overrides_dict.update({device: device_op_overrides})
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def get_device_op_overrides(device: str) -> DeviceOpOverrides:
    """Look up the ``DeviceOpOverrides`` registered for *device*.

    Raises ``KeyError`` if no overrides were registered for *device*.
    """
    assert isinstance(device, str), type(device)

    if not device_op_overrides_dict:
        # Lazily populate the registry: importing these modules has the side
        # effect of calling register_device_op_overrides() for the built-in
        # cpu / mps / cuda / xpu backends.
        from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict[device]
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = {
|
| 574 |
+
torch.bfloat16: torch.float,
|
| 575 |
+
torch.float16: torch.float,
|
| 576 |
+
**{
|
| 577 |
+
dtype: dtype
|
| 578 |
+
for dtype in [
|
| 579 |
+
torch.bool,
|
| 580 |
+
torch.float32,
|
| 581 |
+
torch.float64,
|
| 582 |
+
torch.int8,
|
| 583 |
+
torch.int16,
|
| 584 |
+
torch.int32,
|
| 585 |
+
torch.int64,
|
| 586 |
+
torch.uint8,
|
| 587 |
+
torch.uint16,
|
| 588 |
+
torch.uint32,
|
| 589 |
+
torch.uint64,
|
| 590 |
+
]
|
| 591 |
+
},
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def deduce_output_dtype_by_name(
    op_name: str,
    *args: Any,
    **kwargs: Any,
) -> Optional[torch.dtype]:
    """
    Given op name and a list of input dtypes, deduce the output dtype
    """

    def _dtype_arg(position: int) -> torch.dtype:
        # explicit dtype kwarg wins; otherwise it sits at a known position
        return kwargs["dtype"] if "dtype" in kwargs else args[position]

    if op_name in boolean_ops():
        return torch.bool
    if op_name in ("to_dtype", "index_expr", "constant"):
        return _dtype_arg(-1)
    if op_name in ("rand", "randn"):
        return torch.float
    if op_name in ("get_index", "randint64", "load_seed"):
        return torch.int64
    if op_name == "reduction":
        return _dtype_arg(1)
    if op_name in ("load", "store", "store_reduction"):
        # second positional arg is the buffer name; ask the graph for its dtype
        return V.graph.get_dtype(args[1])  # type: ignore[arg-type]
    if op_name == "to_dtype_bitcast":
        return _dtype_arg(-2)
    # unknown op: caller falls back to input-based promotion
    return None
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def check_dtype(
    buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype
) -> None:
    """Emit a backend-specific dtype assertion for *var* into *buffer*.

    Only active under the corresponding test_configs flags: Triton gets a
    runtime ``tl.static_assert``; cpp gets a C++ ``static_assert``.
    """
    backend = get_current_backend()
    if backend == "triton" and config.test_configs.runtime_triton_dtype_assert:
        buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})")
        return
    if backend != "cpp" or not config.test_configs.static_cpp_dtype_assert:
        return

    from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP

    assert isinstance(var, CppCSEVariable), type(var)
    if dtype == torch.bool:
        if var.is_vec:
            predicate = f"IsVecMaskType<decltype({var})>::value"
        else:
            # operator&(bool, bool) returns int and it can be used as boolean in C++
            predicate = f"std::is_same_v<decltype({var}), bool> || std::is_same_v<decltype({var}), int>"
    else:
        cpp_type = f"decltype({var})"
        if var.is_vec:
            # vectorized vars assert on the element type
            cpp_type = f"typename {cpp_type}::value_type"
        predicate = f"std::is_same_v<{cpp_type}, {DTYPE_TO_CPP[dtype]}>"

    buffer.writeline(f"static_assert({predicate});")
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
class DataTypePropagation:
    """Infers a dtype for every FX node in a ``LoopBody``'s graphs.

    The deduced dtype is stored on each node's ``meta`` under
    ``OptimizationContext.key`` so later codegen passes can read it.
    """

    def __init__(self, body: LoopBody) -> None:
        self.body = body
        # "root" maps to the outermost graph; each subblock graph is keyed by
        # the same key it has in body.subblocks.
        self.graphs: dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Promote the dtypes of *node*'s non-placeholder inputs.

        Returns None when there are no such inputs or when any of them has
        not been propagated yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        # standard pairwise type promotion across all input dtypes
        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype:
        """A subblock node's dtype is the dtype propagated through its graph."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Deduce one node's dtype, trying (in order): op-name rules,
        subgraph propagation, then input promotion."""
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can infer output node if it only have 1 arg
            return None

        if node.target == operator.getitem:
            # getitem inherits the dtype of the value it indexes into
            node_arg = node.args[0]
            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
            return self.deduce_node_dtype(node_arg)

        assert isinstance(node.target, str), type(node.target)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]:
        """Annotate every node in *graph* and return the output node's dtype."""
        assert graph.nodes
        graph_dtype: Optional[torch.dtype] = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self) -> Optional[torch.dtype]:
        """Propagate dtypes starting from the root graph."""
        return self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
        """Convenience: propagate dtypes over a LoopBody in one call."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
        """Convenience: propagate dtypes over a SchedulerNode's loop body."""
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode), type(node)
        assert isinstance(node._body, LoopBody), type(node._body)
        return DataTypePropagation.propagate_loopbody(node._body)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
class PythonPrinter(_PythonPrinter):
    """Sympy printer that runs expressions through the graph's sizevar
    simplifier (when available) before printing."""

    def doprint(
        self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True
    ) -> str:
        # TODO: why are people passing strings to the printer here :think:
        can_simplify = (
            simplify
            and isinstance(expr, sympy.Expr)
            and hasattr(V.graph, "sizevars")
        )
        if can_simplify:
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class OpDecompositions:
    """
    Decomposes inductor ops
    """

    @staticmethod
    def identity(value: OpVarT) -> OpVarT:
        # used to trigger cse
        return value

    @staticmethod
    def reciprocal(x: OpVarT) -> OpVarT:
        # 1 / x
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x: OpVarT) -> OpVarT:
        # x * x
        return ops.mul(x, x)

    @staticmethod
    def erfc(x: OpVarT) -> OpVarT:
        # erfc(x) = 1 - erf(x)
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x: OpVarT) -> OpVarT:
        # erfcx(x) = exp(x^2) * erfc(x)
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x: OpVarT) -> OpVarT:
        # expm1(x) = exp(x) - 1
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x: OpVarT) -> OpVarT:
        # log10(x) = log(x) / log(10)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x: OpVarT) -> OpVarT:
        # log2(x) = log(x) / log(2)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x: OpVarT) -> OpVarT:
        # 2^x = exp(x * log(2))
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x: OpVarT) -> OpVarT:
        # log1p(x) = log(x + 1)
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x: OpVarT) -> OpVarT:
        # sigmoid(x) = 1 / (1 + exp(-x))
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def relu(x: OpVarT) -> OpVarT:
        # relu(x) = max(x, 0)
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT:
        # for backends that don't override this (halide)
        return ops.add(ops.mul(x, y), z)

    @staticmethod
    def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def remainder(a: OpVarT, b: OpVarT) -> OpVarT:
        # Python-style modulo: when a % b is nonzero and its sign differs
        # from b's, shift the C-style result by b.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        return ops.to_dtype(ops.round(a), dtype)
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
# Fragments that never need wrapping parens, checked with fullmatch():
# a bare identifier/number/attribute path, a single "(...)" group, or the
# empty string (note the deliberate trailing "|" alternative).
_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE)
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def _all_in_parens(string: str) -> bool:
    """Return True iff *string* is a single balanced ``(...)`` group.

    i.e. the opening paren at position 0 is only closed at the very last
    character, so wrapping the whole string in another pair would be
    redundant. Asserts if the parens are unbalanced.
    """
    # Check the length before indexing so "" (and 1-char strings) return
    # False instead of raising IndexError on string[0].
    if len(string) < 2 or string[0] != "(":
        return False
    count = 1
    for i, char in enumerate(string[1:]):
        if char == "(":
            count += 1
        elif char == ")":
            count -= 1
            # outer group closed before the end -> e.g. "(a) + (b)"
            if count == 0 and i != len(string) - 2:
                return False
    assert count == 0
    return True
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
    """Base op handler for codegen backends.

    Provides string-building defaults for simple operators plus
    NotImplementedError stubs for ops that must be intercepted by CSEProxy
    or implemented by a specific backend.
    """

    @staticmethod
    def paren(string: OpVarT) -> OpVarT:
        """Wrap *string* in parentheses unless it is already atomic."""
        if (
            isinstance(string, CSEVariable)
            or _RE_PAREN_NOT_NEEDED.fullmatch(string)
            or _all_in_parens(string)
        ):
            # don't put extra parens for strings that are already wrapped in parens
            return string
        return f"({string})"

    @staticmethod
    def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
        # NOTE(review): dtype is unused here; backends needing typed literals
        # override this.
        return repr(value)

    @staticmethod
    def bitwise_not(x: OpVarT) -> OpVarT:
        return f"~{OpOverrides.paren(x)}"

    @staticmethod
    def logical_not(a: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"

    @staticmethod
    def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT:
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name: str, offset: OpVarT) -> OpVarT:
        # a seed load is an ordinary load at an integer offset
        return ops.load(name, sympy.Integer(offset))

    def indirect_indexing(
        self,
        var: OpVarT,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> sympy.Symbol:
        # default: reuse the variable's name as an index symbol; bounds
        # checking/wrapping is left to backends that override this
        return sympy_index_symbol(str(var))

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
        )

    def load(self, name: str, index: sympy.Expr) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: load should be handled by CSEProxy"
        )

    def store(
        self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: store should be handled by CSEProxy"
        )

    def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
        )

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[OpVarT, tuple[OpVarT, ...]],
    ) -> Union[OpVarT, tuple[OpVarT, ...]]:
        raise NotImplementedError(
            f"{type(self).__name__}: reduction should be handled by CSEProxy"
        )

    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[OpVarT, ...], tuple[OpVarT, ...]],
            tuple[OpVarT, ...],
        ],
        values: tuple[OpVarT, ...],
    ) -> tuple[OpVarT, ...]:
        raise NotImplementedError(
            f"{type(self).__name__}: scan should be handled by CSEProxy"
        )

    def sort(
        self,
        dtypes: tuple[torch.dtype, ...],
        values: tuple[OpVarT, ...],
        stable: bool,
        descending: bool,
    ) -> tuple[OpVarT, ...]:
        raise NotImplementedError(
            f"{type(self).__name__}: sort should be handled by CSEProxy"
        )

    def bucketize(
        self,
        values: OpVarT,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: OpVarT,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[OpVarT] = None,
    ) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: bucketize should be handled by CSEProxy"
        )

    def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
        )

    def inline_asm_elementwise(
        self,
        *inputs: OpVarT,
        asm: str,
        constraints: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
        is_pure: bool = True,
        pack: int = 1,
    ) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
        )

    def output(self, *args: OpVarT) -> None:
        # graph-level pseudo-op; must be lowered away before codegen
        raise AssertionError(
            f"{type(self).__name__}: ops.output should not appear at codegen time"
        )

    def placeholder(self, index: int) -> OpVarT:
        # graph-level pseudo-op; must be lowered away before codegen
        raise AssertionError(
            f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
        )

    @staticmethod
    def _unimplemented(name: str) -> Callable[..., OpVarT]:
        """Build a stub method that raises NotImplementedError for ops.*name*."""

        def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
            raise NotImplementedError(
                f"{type(self).__name__} does not implement ops.{name}"
            )

        unimplemented.__name__ = name
        unimplemented.is_unimplemented = True  # type: ignore[attr-defined]
        return unimplemented

    @classmethod
    def _is_unimplemented(cls, name: str) -> bool:
        """True when *name* has no real implementation on this class."""
        fn = getattr(cls, name, None)
        default_fn = getattr(OpsHandler, name, None)
        # missing, inherited default, or a _unimplemented() stub
        return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)

    @classmethod
    def _initialize_pointwise_overrides(cls, target: str) -> None:
        """Install the *target* backend's pointwise implementations from
        pointwise_overrides_data as static methods on *cls*."""
        assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                # backend has no impl: install a stub unless something real
                # is already defined
                if cls._is_unimplemented(funcname):
                    setattr(cls, funcname, cls._unimplemented(funcname))
            else:
                assert funcname not in cls.__dict__, (
                    f"multiple definitions of {funcname} on {cls.__name__}"
                )
                impl.__name__ = funcname
                setattr(cls, funcname, staticmethod(impl))
|
| 1073 |
+
|
| 1074 |
+
|
| 1075 |
+
@dataclasses.dataclass
class OverridesData:
    """Per-backend code templates for one pointwise op.

    Each backend field is a callable mapping argument strings to the
    generated expression string; None means the backend has no
    implementation and gets a stub (see
    OpOverrides._initialize_pointwise_overrides).
    """

    # op name used for registration (e.g. "special_bessel_j0")
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    halide: Optional[Callable[..., str]] = None
    mps: Optional[Callable[..., str]] = None
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
# Maps op name -> per-backend codegen templates; consumed by
# OpOverrides._initialize_pointwise_overrides for each backend.
pointwise_overrides_data: dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x,
        y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? trigamma({y}) : calc_polygamma({y}, {x}))",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
def is_buffer_removed(name: str) -> bool:
    """Return True if *name* was removed (or inplaced-away) at either the
    graph or the kernel level."""
    removal_sets = (
        V.graph.removed_buffers,
        V.kernel.removed_buffers,
        V.graph.inplaced_to_remove,
        V.kernel.inplaced_to_remove,
    )
    for removed in removal_sets:
        if name in removed:
            return True
    return False
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name: str, line: str):
        super().__init__(line)
        self.name = name
        # Nesting deferred lines inside deferred lines is not supported.
        assert not isinstance(line, DeferredLineBase)

    def __call__(self) -> Optional[str]:
        # Drop the line entirely if the buffer it refers to was removed.
        if is_buffer_removed(self.name):
            return None
        return self.line

    def _new_line(self, line: str) -> DeferredLine:
        # Rewrap new text with the same buffer association.
        return DeferredLine(self.name, line)
|
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer for brace-delimited (C-like) output: each indent level
    is accompanied by a literal ``{`` line on entry and ``}`` on exit."""

    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
        def _open(count: int) -> None:
            # Emit "{" and bump indentation `count` times (no-op if count <= 0).
            for _ in range(count):
                self.writeline("{")
                self._indent += 1

        def _close(count: int) -> None:
            # Drop indentation and emit "}" `count` times (no-op if count <= 0).
            for _ in range(count):
                self._indent -= 1
                self.writeline("}")

        @contextlib.contextmanager
        def ctx() -> Iterator[None]:
            # Positive offset opens scopes; negative offset temporarily
            # closes scopes.  Exactly one of the two calls does work.
            _open(offset)
            _close(-offset)
            yield
            # Mirror the entry operations on exit.
            _open(-offset)
            _close(offset)

        return ctx()
|
| 1367 |
+
|
| 1368 |
+
|
| 1369 |
+
class InplacedBuffer(NamedTuple):
    """Record for a group of buffers sharing one in/out kernel argument."""

    # Kernel-internal argument name (e.g. "in_out_ptr0").
    inner_name: str
    # All graph-level buffer names mapped onto this argument.
    other_names: list[str]
|
| 1372 |
+
|
| 1373 |
+
|
| 1374 |
+
@dataclasses.dataclass
class ArgName:
    """A kernel argument name, optionally marked as a Triton constexpr."""

    name: str
    # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
    is_constexpr: bool = False

    def full_name(self) -> str:
        suffix = " : tl.constexpr" if self.is_constexpr else ""
        return f"{self.name}{suffix}"
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
class RemovedArg:
    """Sentinel type marking a kernel argument that has been removed."""

    def __str__(self) -> str:
        return "REMOVED"


# Shared sentinel instance; callers test via isinstance(..., RemovedArg).
REMOVED = RemovedArg()
|
| 1390 |
+
|
| 1391 |
+
|
| 1392 |
+
class KernelArgs:
    """Tracks the arguments of a generated kernel.

    Maps graph-level buffer names to stable kernel-internal names such as
    ``in_ptr0``, ``out_ptr1``, ``in_out_ptr0`` (inplace) and ``ks0`` (size
    variables), and records workspace allocations.  The assignment order of
    the numeric suffixes is significant: it determines the kernel signature.
    """

    @staticmethod
    def _lookup(
        prefix: str,
        odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]],
        name: _T,
    ) -> str:
        # Return the inner name already assigned to `name`, or mint a new
        # one of the form f"{prefix}{index}", where index is the current
        # dict size (so names are dense and in first-use order).
        result: Union[str, RemovedArg] = odict.get(name, REMOVED)
        if isinstance(result, RemovedArg):
            odict[name] = new_result = f"{prefix}{len(odict)}"
            return new_result
        return result

    def __init__(self) -> None:
        # outer (graph) name -> inner (kernel) name
        self.input_buffers: dict[str, str] = {}
        self.output_buffers: dict[str, Union[str, RemovedArg]] = {}
        # outer name -> shared InplacedBuffer (several outer names may map
        # to the same entry)
        self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {}
        # sympy expression -> inner size-variable name
        self.sizevars: dict[sympy.Expr, str] = {}
        self.workspace_args: list[WorkspaceArg] = []

    def __repr__(self) -> str:
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    @staticmethod
    def _buffer_is_marked_removed(name: Any) -> bool:
        # this function is needed by MTIA
        return isinstance(name, RemovedArg)

    def input(self, name: str) -> str:
        """Register (or look up) *name* as a kernel input and return its inner name."""
        if V.graph.scheduler:
            # Resolve mutation aliases to the buffer actually holding the data.
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        # A buffer already used as an output/inplace arg reuses that name.
        if name in self.output_buffers:
            return cast(str, self.output_buffers[name])
        if name in self.inplace_buffers:
            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name: str) -> str:
        """Register (or look up) *name* as a kernel output and return its inner name."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name: str, output_name: str) -> None:
        """Fuse *output_name* onto *input_name*'s storage as an in_out_ptr arg."""
        if input_name in V.graph.unaligned_buffers:
            # Alignment is a property of the storage, which is now shared.
            V.graph.unaligned_buffers.add(output_name)
        assert output_name not in self.inplace_buffers, output_name
        if input_name in self.inplace_buffers:
            # Extend the existing group with the new alias.
            buf = self.inplace_buffers[input_name]
            assert not isinstance(buf, RemovedArg)
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            alive_buffers = [
                val
                for val in self.inplace_buffers.values()
                if not isinstance(val, RemovedArg)
            ]
            removed_buffers = [
                val
                for val in self.inplace_buffers.values()
                if isinstance(val, RemovedArg)
            ]
            # Count distinct live groups plus removed slots so indices of
            # previously-assigned args never shift.
            inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers)
            buf = InplacedBuffer(
                f"in_out_ptr{inplace_buffer_idx}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
        """
        Allocate or extend a workspace buffer of nbytes bytes.

        This function manages the allocation of a workspace buffer. It either creates
        a new WorkspaceArg or extends an existing one.

        Note:
        - Calling this function will in-place mutate the args by adding or updating
          a WorkspaceArg.
        - The codegen for generating the Python argdefs and call_defs will check
          this field and allocate the buffer accordingly.
        - A new argument "ws_ptr" will be present in the generated code.

        Args:
            nbytes (sympy.Expr): The number of bytes to allocate.
            zero_fill (bool): Whether to initialize the buffer to zero.

        Returns:
            Tuple[str, int]: A tuple containing:
            - "ws_ptr": A string identifier for the workspace pointer.
            - offset: An integer representing the byte offset in the workspace.
        """
        arg = WorkspaceArg(
            count=nbytes,
            zero_mode=WorkspaceZeroMode.from_bool(zero_fill),
            device=V.graph.get_current_device_or_throw(),
            outer_name=WorkspaceArg.unique_name(),
        )
        for i, existing_arg in enumerate(self.workspace_args):
            if WorkspaceArg.can_join(existing_arg, arg):
                # Extend the joinable arg; the new region starts at the old size.
                offset = existing_arg.count
                self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg)
                return existing_arg.inner_name, offset
            assert (
                existing_arg.inner_name != arg.inner_name
                and existing_arg.outer_name != arg.outer_name
            ), existing_arg
        self.workspace_args.append(arg)
        return arg.inner_name, 0

    def semaphores(self, min_size: sympy.Expr) -> str:
        """
        Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by
        all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit.

        Warning: multiple calls to this function will return the same buffer.

        Args:
            min_size: the number of int32 semaphores required

        Returns:
            name of the semaphores buffer
        """
        current_device = V.graph.get_current_device_or_throw()
        arg = WorkspaceArg(
            count=min_size,
            zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH,
            dtype=torch.uint32,
            inner_name="sem_ptr",
            outer_name=f"semaphores_{current_device.type}_{current_device.index}",
            device=current_device,
        )
        # NOTE(review): this appends unconditionally even when an arg with the
        # same inner_name already exists (the loop only sanity-checks equality)
        # — presumably de-duplicated downstream; verify against callers.
        for existing_arg in self.workspace_args:
            if existing_arg.inner_name == arg.inner_name:
                assert arg == existing_arg, (arg, existing_arg)
        self.workspace_args.append(arg)
        return arg.inner_name

    def seed_offset(self, name: str, value: int) -> str:
        """Lift constant int *value* into a kernel arg named (a variant of) *name*."""
        assert isinstance(value, int), (type(value), value)
        # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits
        value = sympy.Integer(value)
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            # Disambiguate by appending the count of existing names sharing
            # this prefix.
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name: sympy.Symbol) -> str:
        """Register size symbol *name* and return its inner name (ks0, ks1, ...)."""
        assert isinstance(name, sympy.Symbol), (type(name), name)
        if name.name == "seed":
            self.sizevars[name] = "seed"  # don't manage the name of seeds
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self) -> Iterator[str]:
        # Outer names in signature order (inplace args are keyed under their
        # outer names in input/output dicts as well).
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def arg_name(self, name: str) -> Optional[str]:
        """
        Returns inner name of a given outer name.
        """
        inplaced = self.inplace_buffers.get(name, None)
        if inplaced is not None and not isinstance(inplaced, RemovedArg):
            return inplaced.inner_name
        output_name = self.output_buffers.get(name, None)
        if output_name is not None and not isinstance(output_name, RemovedArg):
            return output_name
        return self.input_buffers.get(name, None)

    def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
        # Hook point for subclasses; default passes the buffer name through.
        return buf

    def wrap_size_arg(self, size: SymbolLike) -> str:
        # Hook point for subclasses; default stringifies the size expression.
        return str(size)

    def cpp_argdefs(
        self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None
    ) -> tuple[list[str], list[str], list[str]]:
        """Build C++ argument definitions, call-site args, and arg types.

        Emission order: inplace buffers, inputs, outputs, size variables —
        this must match python_argdefs.
        """
        from .cpp_utils import INDEX_TYPE

        if dtype_to_cpp_type is None:
            from .cpp_utils import DTYPE_TO_CPP

            dtype_to_cpp_type = DTYPE_TO_CPP

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            # Inputs are read-only, hence const-qualified.
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, maybe_inner in self.output_buffers.items():
            if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = dtype_to_cpp_type[dtype]
            arg_defs.append(f"{cpp_dtype}* {maybe_inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert not self.workspace_args, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(
        self,
    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
        """Build Python-side arg names, call args, precompile args, and types.

        Emission order mirrors cpp_argdefs, plus workspace args at the end.
        """
        arg_defs: list[ArgName] = []
        call_args: list[str] = []
        arg_types: list[Any] = []
        precompile_args: list[KernelArgType] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            arg_defs.append(ArgName(inplaced.inner_name))
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
                continue
            arg_defs.append(ArgName(inner))
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(ArgName(inner))
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        for arg in self.workspace_args:
            arg_defs.append(ArgName(arg.inner_name))
            call_args.append(arg.outer_name)
            precompile_args.append(arg)
            arg_types.append(arg.dtype)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self) -> Iterator[tuple[str, str]]:
        """Yield (other inner name, in_out inner name) pairs for inplaced buffers."""
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield cast(str, self.output_buffers[other]), inplaced.inner_name

    def is_removed(self, name: str) -> bool:
        # Removed means marked removed in BOTH output and inplace maps
        # (a name absent from a map counts as removed there).
        return isinstance(
            self.output_buffers.get(name, REMOVED), RemovedArg
        ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg)

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data? Modeled off of python_argdefs.
    def live_output_buffers(self) -> OrderedSet[str]:
        live_outs: OrderedSet[str] = OrderedSet()
        for inplaced in unique(self.inplace_buffers.values()):
            if isinstance(inplaced, RemovedArg):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
                continue
            live_outs.add(outer)
        return live_outs
|
| 1723 |
+
|
| 1724 |
+
|
| 1725 |
+
class CSEVariable:
|
| 1726 |
+
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
|
| 1727 |
+
To do so, the backends can simply overload `Kernel.create_cse_var`
|
| 1728 |
+
The "CSEVariable.update_on_args" method gives you a hook for annotations
|
| 1729 |
+
See example of TritonCSEVariable in triton.py
|
| 1730 |
+
"""
|
| 1731 |
+
|
| 1732 |
+
def __init__(
|
| 1733 |
+
self,
|
| 1734 |
+
name: str,
|
| 1735 |
+
bounds: ValueRanges[Any],
|
| 1736 |
+
dtype: Optional[torch.dtype] = None,
|
| 1737 |
+
):
|
| 1738 |
+
super().__init__()
|
| 1739 |
+
assert isinstance(bounds, ValueRanges), type(bounds)
|
| 1740 |
+
self.name = name
|
| 1741 |
+
self.bounds = bounds
|
| 1742 |
+
self.use_count = 1 # track how many times this expression is used
|
| 1743 |
+
self.dtype = dtype
|
| 1744 |
+
|
| 1745 |
+
def __str__(self) -> str:
|
| 1746 |
+
return self.name
|
| 1747 |
+
|
| 1748 |
+
def __hash__(self) -> int:
|
| 1749 |
+
return hash(self.name)
|
| 1750 |
+
|
| 1751 |
+
def __eq__(self, other: object) -> bool:
|
| 1752 |
+
return isinstance(other, CSEVariable) and other.name == self.name
|
| 1753 |
+
|
| 1754 |
+
def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
|
| 1755 |
+
pass
|
| 1756 |
+
|
| 1757 |
+
def __repr__(self) -> str:
|
| 1758 |
+
return f"{self.__class__.__name__}({self.name!r})"
|
| 1759 |
+
|
| 1760 |
+
|
| 1761 |
+
# Backend-specific CSE cache-key type; defaults to plain strings.
AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
# Backend-specific CSEVariable subclass used by CSE/Kernel generics.
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)

if TYPE_CHECKING:
    # Key for CSE.reduction_cache: (dtype, reduction kind, input value(s)).
    ReductionCacheKey = tuple[
        torch.dtype,
        ReductionType,
        Union[CSEVariable, tuple[CSEVariable, ...]],
    ]
|
| 1770 |
+
|
| 1771 |
+
|
| 1772 |
+
class CSE(Generic[CSEVariableType, AugmentedKeyT]):
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix: str = "",
        suffix: str = "",
        name_prefix: str = "tmp",
        iter_buffers: Optional[itertools.count[int]] = None,
        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
        reduction_cache: Optional[
            MutableMapping[ReductionCacheKey, CSEVariableType]
        ] = None,
        varname_map: Optional[dict[str, CSEVariableType]] = None,
    ):
        # prefix/suffix wrap every generated assignment line.
        self.prefix = prefix
        self.suffix = suffix
        # expression text (possibly augmented) -> variable holding its value
        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
        self.name_prefix = name_prefix
        # buffer name -> variable last stored into it
        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
            reduction_cache or {}
        )
        # Shared counter so clones never reuse a variable number.
        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
        self.invalidated_stores: OrderedSet[str] = OrderedSet()
        # variable name -> variable object, for all vars ever created
        self.varname_map: dict[str, CSEVariableType] = varname_map or {}

    def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
        """Drop all cached entries except those whose value is in *keep_vars*."""
        # Iterate over a snapshot since we delete while walking.
        for name, tmp in [*self.store_cache.items()]:
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        if keep_vars:
            self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
        else:
            self._cache = {}

    def clone(self) -> Self:
        # Note: caches and the id counter are shared (not copied), but
        # _cache starts fresh in the clone's __init__.
        return type(self)(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
            reduction_cache=self.reduction_cache,
        )

    def scoped_copy(self) -> Self:
        """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
        new_cse = self.clone()
        new_cse._cache = ScopedDict(self._cache)
        new_cse.reduction_cache = ScopedDict(self.reduction_cache)
        new_cse.store_cache = ScopedDict(self.store_cache)
        return new_cse

    def augment_key(self, cache_key: str) -> AugmentedKeyT:
        "Override this method to augment cache key with backend specifics"
        return cast(AugmentedKeyT, cache_key)

    def put(self, cache_key: str, val: CSEVariableType) -> None:
        self._cache[self.augment_key(cache_key)] = val

    def contains(self, cache_key: str) -> bool:
        return self.augment_key(cache_key) in self._cache

    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
        return self._cache.get(self.augment_key(cache_key), None)

    def get(self, cache_key: str) -> CSEVariableType:
        return self._cache[self.augment_key(cache_key)]

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write: bool = True,
        assignment: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariableType:
        """Return a variable holding *expr*, emitting code into *buffer* on
        a cache miss (when *write* is True); cache hits emit nothing."""
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return cast(CSEVariableType, expr)
        elif isinstance(expr, IndentedBuffer):
            cache_key = expr.getvalue()
        elif isinstance(expr, DeferredLineBase):
            cache_key = expr.line
        else:
            assert isinstance(expr, str)
            cache_key = expr
        var = self.try_get(cache_key)
        if not var:
            var = self.newvar(bounds, dtype)
            self.put(cache_key, var)
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    # Multi-line expression: assignment header, then splice.
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                elif isinstance(expr, DeferredLineBase):
                    assert assignment
                    buffer.writeline(
                        expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}")
                    )
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)

                # cpp backend cannot determine is_vec at this point
                if (
                    assignment
                    and (
                        config.test_configs.runtime_triton_dtype_assert
                        or config.test_configs.static_cpp_dtype_assert
                    )
                    and dtype is not None
                    and get_current_backend() != "cpp"
                ):
                    check_dtype(buffer, var, dtype)

        else:
            # Cache hit: just tighten bounds and bump the use counter.
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var

    def newvar(
        self,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariableType:
        """Create a fresh auto-named variable (name_prefix + next id)."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds, dtype)
        self.varname_map[var_name] = var
        return var

    def namedvar(
        self,
        name: str,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariableType:
        """Create a variable with an explicit *name*; names must be unique."""
        torch._check_value(
            name not in self.varname_map, lambda: f"duplicate name: {name}"
        )
        var = V.kernel.create_cse_var(name, bounds, dtype)
        self.varname_map[name] = var
        return var
|
| 1938 |
+
|
| 1939 |
+
|
| 1940 |
+
class CodeGen:
|
| 1941 |
+
def __init__(self) -> None:
|
| 1942 |
+
super().__init__()
|
| 1943 |
+
self.exit_stack = contextlib.ExitStack()
|
| 1944 |
+
|
| 1945 |
+
def __enter__(self) -> Self:
|
| 1946 |
+
self.exit_stack.__enter__()
|
| 1947 |
+
return self
|
| 1948 |
+
|
| 1949 |
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
| 1950 |
+
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
| 1951 |
+
|
| 1952 |
+
|
| 1953 |
+
class Kernel(CodeGen, Generic[CSEVariableType]):
|
| 1954 |
+
newvar_prefix: str = ""
|
| 1955 |
+
suffix: str = ""
|
| 1956 |
+
overrides: Optional[Callable[[], OpsHandler[Any]]] = None
|
| 1957 |
+
|
| 1958 |
+
def __init__(
|
| 1959 |
+
self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
|
| 1960 |
+
) -> None:
|
| 1961 |
+
super().__init__()
|
| 1962 |
+
if increase_kernel_count:
|
| 1963 |
+
metrics.generated_kernel_count += 1
|
| 1964 |
+
self.args = args or KernelArgs()
|
| 1965 |
+
self.loads = IndentedBuffer()
|
| 1966 |
+
self.compute = IndentedBuffer()
|
| 1967 |
+
self.stores = IndentedBuffer()
|
| 1968 |
+
|
| 1969 |
+
self.num_load = 0
|
| 1970 |
+
self.num_reduction = 0
|
| 1971 |
+
|
| 1972 |
+
self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
|
| 1973 |
+
self.must_keep_buffers: OrderedSet[str] = OrderedSet()
|
| 1974 |
+
self.store_buffer_names: OrderedSet[str] = OrderedSet()
|
| 1975 |
+
self._load_mask: Optional[str] = None
|
| 1976 |
+
self._load_other: Union[None, int, float] = None
|
| 1977 |
+
# OrderedSet in set_current_node
|
| 1978 |
+
self.current_node: Optional[SchedulerNode] = None
|
| 1979 |
+
self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
|
| 1980 |
+
|
| 1981 |
+
self.removed_buffers: OrderedSet[str] = OrderedSet()
|
| 1982 |
+
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
|
| 1983 |
+
|
| 1984 |
+
# key: the buffer to write
|
| 1985 |
+
# value: the buffer to read and whose memory can be reused for
|
| 1986 |
+
# the buffer specified by key
|
| 1987 |
+
self.inplace_update_buffers: dict[str, str] = {}
|
| 1988 |
+
# Set minimum number of elements processed per thread.
|
| 1989 |
+
self.min_elem_per_thread = 1
|
| 1990 |
+
self.kernel_name: Optional[str] = None
|
| 1991 |
+
|
| 1992 |
+
@contextlib.contextmanager
|
| 1993 |
+
def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
|
| 1994 |
+
prior = self.current_node
|
| 1995 |
+
self.current_node = node
|
| 1996 |
+
self.node_to_bounds = node._body.bounds().get_bounds()
|
| 1997 |
+
try:
|
| 1998 |
+
yield
|
| 1999 |
+
finally:
|
| 2000 |
+
self.current_node = prior
|
| 2001 |
+
|
| 2002 |
+
@contextlib.contextmanager
|
| 2003 |
+
def swap_buffers(
|
| 2004 |
+
self,
|
| 2005 |
+
lb: IndentedBuffer,
|
| 2006 |
+
cb: Optional[IndentedBuffer] = None,
|
| 2007 |
+
sb: Optional[IndentedBuffer] = None,
|
| 2008 |
+
) -> Iterator[None]:
|
| 2009 |
+
if cb is None:
|
| 2010 |
+
cb = lb
|
| 2011 |
+
if disallow_stores := sb is None:
|
| 2012 |
+
sb = IndentedBuffer()
|
| 2013 |
+
loads = self.loads
|
| 2014 |
+
compute = self.compute
|
| 2015 |
+
stores = self.stores
|
| 2016 |
+
cse = self.cse
|
| 2017 |
+
self.loads = lb
|
| 2018 |
+
self.compute = cb
|
| 2019 |
+
self.stores = sb
|
| 2020 |
+
self.cse = cse.scoped_copy()
|
| 2021 |
+
try:
|
| 2022 |
+
yield
|
| 2023 |
+
finally:
|
| 2024 |
+
self.loads = loads
|
| 2025 |
+
self.compute = compute
|
| 2026 |
+
self.stores = stores
|
| 2027 |
+
self.cse = cse
|
| 2028 |
+
if disallow_stores:
|
| 2029 |
+
assert not sb, "unexpected store inside swap_buffers"
|
| 2030 |
+
|
| 2031 |
+
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
| 2032 |
+
raise NotImplementedError
|
| 2033 |
+
|
| 2034 |
+
def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
| 2035 |
+
"""A load the depends on an index we have read"""
|
| 2036 |
+
prior = self.loads
|
| 2037 |
+
try:
|
| 2038 |
+
# put the load in the compute section as it might have deps
|
| 2039 |
+
self.loads = self.compute
|
| 2040 |
+
return self.load(name, index)
|
| 2041 |
+
finally:
|
| 2042 |
+
self.loads = prior
|
| 2043 |
+
|
| 2044 |
+
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
|
| 2045 |
+
raise NotImplementedError
|
| 2046 |
+
|
| 2047 |
+
def store(
|
| 2048 |
+
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
|
| 2049 |
+
) -> None:
|
| 2050 |
+
raise NotImplementedError
|
| 2051 |
+
|
| 2052 |
+
def reduction(
|
| 2053 |
+
self,
|
| 2054 |
+
dtype: torch.dtype,
|
| 2055 |
+
src_dtype: torch.dtype,
|
| 2056 |
+
reduction_type: ReductionType,
|
| 2057 |
+
value: Union[CSEVariable, tuple[CSEVariable, ...]],
|
| 2058 |
+
) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
|
| 2059 |
+
raise NotImplementedError
|
| 2060 |
+
|
| 2061 |
+
def scan(
|
| 2062 |
+
self,
|
| 2063 |
+
dtypes: tuple[torch.dtype, ...],
|
| 2064 |
+
combine_fn: Callable[
|
| 2065 |
+
[tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
|
| 2066 |
+
],
|
| 2067 |
+
values: tuple[CSEVariable, ...],
|
| 2068 |
+
) -> tuple[CSEVariable, ...]:
|
| 2069 |
+
raise NotImplementedError
|
| 2070 |
+
|
| 2071 |
+
def sort(
|
| 2072 |
+
self,
|
| 2073 |
+
dtypes: tuple[torch.dtype, ...],
|
| 2074 |
+
values: tuple[CSEVariable, ...],
|
| 2075 |
+
stable: bool,
|
| 2076 |
+
descending: bool,
|
| 2077 |
+
) -> tuple[CSEVariable, ...]:
|
| 2078 |
+
raise NotImplementedError
|
| 2079 |
+
|
| 2080 |
+
def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
|
| 2081 |
+
raise NotImplementedError
|
| 2082 |
+
|
| 2083 |
+
def bucketize(
    self,
    values: CSEVariable,
    boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
    boundary_indices: CSEVariable,
    indexing_dtype: torch.dtype,
    right: bool,
    sorter: Optional[tuple[str, sympy.Expr]] = None,
    sorter_indices: Optional[CSEVariable] = None,
) -> CSEVariable:
    """
    Backend hook for the bucketize op; implemented by subclasses.

    See [Note: Inductor bucketize op]
    """
    raise NotImplementedError
@property
def assert_function(self) -> str:
    """Name of the backend's device-side assert function (used by indirect_assert)."""
    raise NotImplementedError
def indirect_assert(
    self,
    var: Union[CSEVariable, str],
    lower: Optional[str],
    upper: Optional[str],
    mask: Optional[Union[CSEVariable, str]] = None,
) -> str:
    """Build a device-side assertion checking ``lower <= var < upper``.

    Either bound may be None (checked one-sided); at least one must be set.
    When ``mask`` is given, the check is only enforced for active lanes.
    Returns the assert call as a source-code string.
    """
    if isinstance(var, CSEVariable):
        var = str(var)
    assert isinstance(var, str), type(var)
    assert lower is None or isinstance(lower, str)
    assert upper is None or isinstance(upper, str)

    if lower and upper:
        # Each comparison is parenthesized because `&` binds tighter than
        # `<=`/`<` in Python; triton also accepts and/or/not which would be
        # less error-prone.
        guard = f"({lower} <= {var}) & ({var} < {upper})"
        readable = f"{lower} <= {var} < {upper}"
    elif lower:
        readable = guard = f"{lower} <= {var}"
    else:
        assert upper
        readable = guard = f"{var} < {upper}"

    if mask:
        # Masked-off lanes must not fire the assert.
        guard = f"({guard}) | ~({mask})"

    return f'{self.assert_function}({guard}, "index out of bounds: {readable}")'
def check_bounds(
    self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
) -> None:
    """Emit bounds checks for ``expr`` against ``size``; ``lower``/``upper``
    select which side(s) to assert. Implemented by backend subclasses."""
    raise NotImplementedError
def index_to_str(self, index: sympy.Expr) -> str:
    """Render a sympy index expression as backend source code; backend-specific."""
    raise NotImplementedError
def __enter__(self) -> Self:
    """Install this kernel and a CSEProxy ops handler for the context's duration."""
    super().__enter__()
    # overrides is the ops-override class for this backend; it must be set
    # before the kernel can be used as a context manager.
    assert self.overrides
    self.exit_stack.enter_context(
        V.set_ops_handler(CSEProxy(self, self.overrides()))
    )
    self.exit_stack.enter_context(V.set_kernel_handler(self))
    return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Drop kernel-local buffers, then unwind the handler stack installed in __enter__."""
    self.remove_kernel_local_buffers()
    super().__exit__(exc_type, exc_val, exc_tb)
def remove_kernel_local_buffers(self) -> None:
    """
    Any buffers that are both created and have a last use in the
    same kernel can be removed.

    Note that V.graph.scheduler can be None when codegening triton template
    kernels.
    """
    scheduler = V.graph.scheduler
    if not scheduler:
        return
    # Names of the ops fused into this kernel, derived from the buffers it stores.
    fused_node_names = OrderedSet(
        scheduler.name_to_buf[buf].defining_op_name()
        for buf in self.store_buffer_names
        if buf in scheduler.name_to_buf
    )
    names_to_remove: OrderedSet[str] = OrderedSet()
    for name in self.store_buffer_names:
        # A stored buffer is removable only if nothing outside the fused set
        # needs it: not pinned, not also an input, and the scheduler agrees.
        if (
            name not in self.must_keep_buffers
            and name not in self.args.input_buffers
            and scheduler.can_buffer_be_removed_through_fusion(
                name, fused_node_names
            )
        ):
            names_to_remove.add(name)

    for name in names_to_remove:
        if name in self.args.inplace_buffers:
            buf = self.args.inplace_buffers[name]
            if isinstance(buf, RemovedArg):
                # Already removed on a previous pass.
                continue
            # An inplace buffer aliases others; it can only go away if every
            # alias is also being removed.
            remove = all(n in names_to_remove for n in buf.other_names)
            if remove:
                self.remove_inplace_buffer(name)
                self.inplaced_to_remove.add(name)
        else:
            self.remove_buffer(name)
def remove_buffer(self, name: str) -> None:
    """Mark output buffer ``name`` as removed from this kernel's arguments."""
    # Assign a special value instead of deleting the entry
    # because we still rely on output_buffers's length to
    # generate unique arg name.
    log.debug("remove_buffer(%r)", name)
    self.args.output_buffers[name] = REMOVED
    self.removed_buffers.add(name)
def remove_inplace_buffer(self, name: str) -> None:
    """Mark inplace buffer ``name`` as removed from this kernel's arguments."""
    log.debug("removing_inplace_buffer(%r)", name)
    # Like remove_buffer: keep the entry (as a sentinel) so arg-name
    # generation stays stable.
    self.args.inplace_buffers[name] = REMOVED
    self.removed_buffers.add(name)
def rename_indexing(
    self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
) -> sympy.Expr:
    """Register size symbols used by ``index`` as kernel args and substitute
    their arg names into the expression."""
    # adds the necessary kernel args for index expressions
    # and renames variables in index expressions to kernel arg names
    # NOTE(review): the list/tuple branch returns a list, not sympy.Expr as
    # annotated.
    if isinstance(index, (list, tuple)):
        return [self.rename_indexing(x) for x in index]
    index = V.graph.sizevars.simplify(index)
    # Sort by name for a deterministic argument registration order.
    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
    replacements = {
        x: self.args.size(x)
        for x in sorted_symbols
        # Only size-like symbols become kernel args; iteration variables etc.
        # are left alone.
        if symbol_is_type(
            x,
            (
                SymT.UNBACKED_INT,
                SymT.SIZE,
                SymT.PRECOMPUTED_SIZE,
            ),
        )
    }
    return sympy_subs(index, replacements)
def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
    """Construct a CSEVariable; factory hook — subclasses presumably override
    to return a backend-specific variable type (confirm at subclass sites)."""
    return CSEVariable(*args, **kwargs)
def arg_name(self, node: IRNode) -> Optional[str]:
    """
    Returns arg name of a given input or output node.
    """
    # None nodes have no argument; otherwise look the name up in kernel args.
    return None if node is None else self.args.arg_name(node.get_name())
@dataclasses.dataclass
class OptimizationContext:
    """Per-node codegen metadata stored in an FX node's ``meta`` dict under ``key``."""

    # Dict key under which this context is stored in node.meta.
    key: ClassVar[str] = "opt_ctx"

    # Dtype of the node's value; None when not yet determined.
    dtype: Optional[torch.dtype] = None
    # Name of the ops.* call this context belongs to.
    ops_name: str = ""
@functools.cache
def jinja2_env() -> Any:
    """Return a process-wide jinja2 Environment (StrictUndefined so missing
    template variables raise), or None when jinja2 is not installed."""
    try:
        import jinja2
    except ImportError:
        return None
    return jinja2.Environment(
        undefined=jinja2.StrictUndefined,
    )
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(
        source: str, num_indents: int, indents_spacing: int = 4
    ) -> str:
        """Indent every line of ``source`` except the first by
        ``num_indents * indents_spacing`` spaces (used as a jinja filter)."""
        lines = source.splitlines(True)
        if len(lines) > 1:
            lines[1:] = [
                (" " * indents_spacing * num_indents) + line for line in lines[1:]
            ]
        return "".join(lines)

    @staticmethod
    def _template_from_string(source: str) -> Any:
        """Compile ``source`` into a jinja2 template, or return None when
        jinja2 is unavailable. Syntax errors are re-raised with the offending
        template lines and a column marker for easier debugging."""
        env = jinja2_env()
        if env is None:
            return None
        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        from jinja2 import TemplateSyntaxError

        try:
            return env.from_string(source)
        except TemplateSyntaxError as e:
            # Wrap the original error so str() shows template context.
            class DetailedTemplateSyntaxError(TemplateSyntaxError):
                def __init__(self, original_error: TemplateSyntaxError) -> None:
                    super().__init__(
                        original_error.message,
                        original_error.lineno,
                        original_error.name,
                        original_error.filename,
                    )
                    self.original_error = original_error

                def __str__(self) -> str:
                    error_info = f"Error in template at line {self.lineno}\n"
                    error_info += f"Error message: {self.message}\n"
                    if hasattr(self.original_error, "source"):
                        lines = self.original_error.source.split("\n")
                        error_info += "Context:\n"
                        # Show a small window around the failing line.
                        start = max(0, self.lineno - 2)
                        end = min(len(lines), self.lineno + 2)
                        for i in range(start, end):
                            if i == self.lineno - 1:
                                # Arrow marks the failing line; a caret marks
                                # the column when available.
                                error_info += f"{i + 1}: --> {lines[i]}\n"
                                if hasattr(self.original_error, "column"):
                                    error_info += (
                                        " "
                                        + " " * (self.original_error.column - 1)
                                        + "^\n"
                                    )
                            else:
                                error_info += f"{i + 1}: {lines[i]}\n"
                    return error_info

            raise DetailedTemplateSyntaxError(e) from e

    @staticmethod
    def _fake_get_dtype(
        fake_outs: Union[list[Buffer], Buffer],
    ) -> Callable[[str], torch.dtype]:
        """Return a ``get_dtype(name)`` lookup that answers from ``fake_outs``
        first and falls back to ``V.graph.get_dtype`` for other buffers."""
        _get_dtype_real = V.graph.get_dtype
        if isinstance(fake_outs, (list, tuple)):
            lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
        else:
            lookup = {fake_outs.get_name(): fake_outs.get_dtype()}

        def get_dtype(name: str) -> torch.dtype:
            result = lookup.get(name)
            if result is not None:
                return result
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str) -> None:
        # Template identifier (e.g. used by subclasses for codegen/logging).
        self.name = name

    def maybe_append_choice(
        self, choices: list[Any], **kwargs: Any
    ) -> Optional[NotImplementedError]:
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.
        Returns None if success, otherwise returns the error.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
            return None
        except NotImplementedError as e:
            # A template that cannot handle these inputs signals it by raising;
            # record (not raise) so other choices can still be tried.
            log.info(
                "Cannot Append Choice: %s. KernelTemplate type is %s",
                e,
                type(self),
                stack_info=log.getEffectiveLevel() < logging.INFO,
            )
            return e

    def generate(self, **kwargs: Any) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError
class CSEProxy(DefaultHandler):
    """Ops handler that wraps a backend handler, routing every op result
    through the kernel's CSE cache while attaching value-range bounds and
    (for some backends) propagated dtypes to the resulting variables."""

    name = "CSEProxy"

    def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
        super().__init__()
        from ..bounds import ValueRangeAnalysis

        # Used to compute value ranges for ops that have no precomputed FX bound.
        self.vr_analysis = ValueRangeAnalysis()
        self.kernel = kernel
        self.parent_handler = parent_handler

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Fallback for every op: evaluate via the parent handler, then CSE the
        result(s) with bounds and (backend-dependent) dtype information."""
        bounds = self._bound_variable(name, *args, **kwargs)

        value = getattr(self.parent_handler, name)(*args, **kwargs)
        dtype_handler = DtypePropagationOpsHandler()

        backend = get_current_backend()

        output_dtype = None
        if name == "masked" and backend == "triton":
            output_dtype = value.dtype
        elif name == "masked" and backend == "cpp":
            output_dtype = V.interpreter.current_node.meta.get(
                OptimizationContext.key, None
            ).dtype
        elif backend in ("triton", "cpp", "mps"):
            # Generic dtype propagation rule for this op.
            dtype_op = getattr(dtype_handler, name)
            output_dtype = dtype_op(*args, **kwargs)

        if backend in ("triton", "cpp"):
            # maybe there are some exceptions on mps?
            assert output_dtype is not None

        output_idx = 0

        def do_cse(v: str) -> CSEVariable:
            # we tree_map over the output, so we need to fetch corresponding dtype
            nonlocal output_idx
            var_dtype: Optional[torch.dtype] = (
                output_dtype[output_idx]
                if isinstance(output_dtype, (list, tuple))
                else output_dtype
            )
            output_idx += 1

            # some cpp op implementations don't set the dtype
            if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None:
                v.dtype = var_dtype

            csevar = V.kernel.cse.generate(
                V.kernel.compute,
                v,
                bounds=bounds,
                dtype=output_dtype,
            )

            csevar.update_on_args(name, args, kwargs)

            if (
                config.test_configs.runtime_triton_dtype_assert
                or config.test_configs.static_cpp_dtype_assert
            ):
                assert var_dtype is not None
                check_dtype(V.kernel.compute, csevar, var_dtype)
            return csevar

        return pytree.tree_map(do_cse, value)

    def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
        """
        If the variable comes from an FX node, we forward the bound we have
        already computed. Else, when codegen'ing another op, we try to compute
        its bounds from the arguments' bounds.
        """
        from ..bounds import ValueRangeAnalysis
        from ..select_algorithm import TritonTemplateKernel
        from .cuda.cuda_kernel import CUDATemplateKernel

        # Template kernels have no FX-derived bounds information.
        if isinstance(V.kernel, TritonTemplateKernel):
            return ValueRanges.unknown()

        if isinstance(V.kernel, CUDATemplateKernel):
            return ValueRanges.unknown()

        fx_node = V.interpreter.current_node
        if fx_node.target == name and self.kernel.node_to_bounds is not None:
            assert isinstance(self.kernel.node_to_bounds, dict), type(
                self.kernel.node_to_bounds
            )
            return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
        elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
            # These create lots of inner strings. We would need to compute the bounds at the ops
            # We will also likely not get much from computing VRs on these nodes
            if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")):
                return ValueRanges.unknown()

            # We assume that the inputs come from `ops.` and are not strings. If you want to generate
            # intermediary strings, wrap them in CSE variables with properly initialised bounds.

            # If there is no FX bound but we know how to compute one we do so
            assert not kwargs

            def arg_to_bound(x: Any) -> Any:
                if isinstance(x, CSEVariable):
                    return x.bounds
                elif isinstance(x, sympy.Expr):
                    return bound_sympy(x)
                else:
                    return x

            arg_bounds = list(map(arg_to_bound, args))
            return getattr(self.vr_analysis, name)(*arg_bounds)
        return ValueRanges.unknown()

    def indirect_indexing(
        self,
        var: CSEVariable,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> sympy.Symbol:
        """Convert a runtime value into an index symbol, optionally wrapping
        negative indices (Python-style) and emitting bounds assertions."""
        if isinstance(size, int):
            size = sympy.Integer(size)
        assert isinstance(size, sympy.Expr), (type(size), size)
        # Skip CSE since this doesn't return an expression

        if var.bounds.lower < 0:
            if wrap_neg:
                stm = ops.add(var, ops.index_expr(size, torch.long))
                # Mixed negative and non-negative
                if var.bounds.upper >= 0:
                    lt = ops.lt(var, 0)
                    stm = ops.where(lt, stm, var)
            else:
                stm = var

            # Propagate bounds as we know how to compute them properly
            new_bounds = ValueRanges.unknown()
            if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
                # Take the negative part of the bound and add size to it
                # Then take union of that and the positive part
                # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                new_bounds = ValueRanges(
                    neg_bounds.lower + size, neg_bounds.upper + size
                )
                # We don't have a good way of representing the empty range
                if var.bounds.upper >= 0:
                    pos = var.bounds & ValueRanges(0, int_oo)
                    new_bounds = new_bounds | pos

            var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)

        sympy_var = self.parent_handler.indirect_indexing(var, size, check)
        if generate_assert(check):
            assert_lower = not (var.bounds.lower >= 0)
            # value ranges cannot prove x < s when x and s are symbols
            assert_upper = not isinstance(size, sympy.Number) or not (
                var.bounds.upper < size
            )
            self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper)
        return sympy_var

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        """Forward a bounds-check request to the kernel."""
        return self.kernel.check_bounds(expr, size, lower, upper)

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Load buffer ``name`` at ``index``, serving from the store cache when
        possible and routing indirect indices through indirect_load."""
        if name in self.kernel.cse.invalidated_stores:
            # A load from an invalidated store requires us to
            # keep the actual buffer around
            V.kernel.must_keep_buffers.add(name)
        if free_symbol_is_type(index, SymT.TMP):
            # TMP symbols come from indirect_indexing; the load depends on a
            # computed value.
            return self.kernel.indirect_load(name, index)
        store_cache = self.kernel.cse.store_cache
        if name in store_cache:
            return store_cache[name]
        out = self.kernel.load(name, index)
        # count load that is not in the store_cache, and also not in the
        # cse cache.
        if out.use_count == 1:
            self.kernel.num_load += 1
        return out

    def _update_store_cache(self, name: str, value: CSEVariable) -> None:
        # Record the stored value so later loads of this buffer (or any buffer
        # it mutates) can be forwarded without re-reading memory.
        self.kernel.cse.store_cache[name] = value
        if self.kernel.current_node and name in V.graph.name_to_buffer:
            buf = self.kernel.current_node.get_output(name)
            for other_name in buf.get_mutations():
                self.kernel.cse.store_cache[other_name] = value

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Store ``value`` into buffer ``name``; plain stores also populate the
        store cache (non-default modes, e.g. atomics, do not)."""
        self.kernel.store_buffer_names.add(name)
        if mode is None:
            self._update_store_cache(name, value)
        if name not in V.graph.removed_buffers:
            self.kernel.store(name, index, value, mode=mode)

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
        """Store a reduction result into buffer ``name``."""
        self.kernel.store_buffer_names.add(name)
        self._update_store_cache(name, value)

        if name not in V.graph.removed_buffers:
            return self.kernel.store_reduction(name, index, value)

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
        """Forward a reduction to the kernel, counting it for heuristics."""
        self.kernel.num_reduction += 1
        return self.kernel.reduction(dtype, src_dtype, reduction_type, value)

    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]],
            tuple[CSEVariable, ...],
        ],
        values: tuple[CSEVariable, ...],
    ) -> tuple[CSEVariable, ...]:
        """Forward a scan op to the kernel."""
        return self.kernel.scan(dtypes, combine_fn, values)

    def sort(
        self,
        dtypes: tuple[torch.dtype, ...],
        values: tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> tuple[CSEVariable, ...]:
        """Forward a sort op to the kernel."""
        return self.kernel.sort(dtypes, values, stable, descending)

    def bucketize(
        self,
        values: CSEVariable,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: CSEVariable,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[CSEVariable] = None,
    ) -> CSEVariable:
        """
        [Note: Inductor bucketize op]

        Inputs:
        -------
        values: the values to be bucketized.
        boundaries: a tuple containing
          (a) the name of the boundaries tensor (which must be sorted, unless
          the sorting tensor is present),
          (b) the length of the tensor in the last dimension (i.e. the length of
          one set of boundaries),
          (c) the number of elements in the underlying storage (i.e. the length
          of the flattened tensor, ignoring striding), and
          (d) the stride of the tensor in the last dimension.
        boundary_indices: indices into a flattened version of the boundaries
        tensor, of the same size and shape as "values". Each index points to
        the first element in the set of boundaries to be used for the
        corresponding value.
        indexing_dtype: the dtype to use when indexing into the boundaries
        tensor. This must be int64 or int32. This additionally specifies the
        dtype of the return value.
        right: see "Details" below.
        sorter: an optional tuple containing
          (a) the name of an optional sorting tensor, used to access unsorted
          boundaries without reordering the boundaries tensor, and
          (b) the stride of the tensor in the last dimension.
        The values in the sorting tensor are used as indices into the *last*
        dimension of the boundaries tensor, with all other indices matching.
        The size of the sorting and boundaries tensors must be equivalent.
        sorter_indices: must be present if the sorting array is present; see
        "boundary_indices" for the equivalent definition for the boundaries
        tensor.

        Output:
        -------
        The buckets each value belongs in, within a given set of boundaries. 0
        indicates a position before the first boundary, and len(boundaries_set)
        represents a position after the last boundary.

        Details:
        --------
        Given a value and a set of boundaries, calculate the bucket that each
        value belongs to. This works differently in 1-D and N-D cases.

        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True
        return =   [[ 0, 1, 1, 1], [1, 3, 3, 4]].

        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True
        return =   [[ 0, 1, 1, 1], [0, 1, 1, 2]]

        Note that in the N-D boundaries case, the shape of "values" and
        "boundaries" must match in every dimension _except_ the last.

        When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]].
        When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]).

        Boundaries must be non-decreasing, or a sorter must be provided which
        would re-index offsets in a non-decreasing order (e.g. the second output
        of torch.sort(offsets)). Otherwise, the result is undefined.
        """
        return self.kernel.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import itertools
|
| 4 |
+
from typing import Any, Callable, Optional
|
| 5 |
+
from unittest.mock import patch
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
from .. import ir
|
| 10 |
+
from ..select_algorithm import PartialRender
|
| 11 |
+
from ..virtualized import V
|
| 12 |
+
from .common import ArgName
|
| 13 |
+
from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
|
| 14 |
+
from .cpp_micro_gemm import LayoutType
|
| 15 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 16 |
+
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition
|
| 20 |
+
GEMM_SINGLE_THREAD_MM_STUB = r"""
|
| 21 |
+
{{kernel.def_kernel(
|
| 22 |
+
inputs={"X": X, "W": W},
|
| 23 |
+
outputs={"Y": Y_2d},
|
| 24 |
+
aliases=aliases,
|
| 25 |
+
function_name=kernel_name+"_single_thread_mm",
|
| 26 |
+
extra_sizevars=BY_sizevars + [b_index],
|
| 27 |
+
placeholder="<SINGLE_THREAD_MM_DEF_FOR_BMM>")}}"""
|
| 28 |
+
|
| 29 |
+
GEMM_THREADED_MM_STUB = r"""
|
| 30 |
+
{{kernel.def_kernel(
|
| 31 |
+
inputs={"X": X, "W": W},
|
| 32 |
+
outputs={"Y": Y_2d},
|
| 33 |
+
aliases=aliases,
|
| 34 |
+
function_name=kernel_name+"_threaded_mm",
|
| 35 |
+
extra_sizevars=BY_sizevars + [b_index],
|
| 36 |
+
placeholder="<THREADED_MM_DEF_FOR_BMM>")}}"""
|
| 37 |
+
|
| 38 |
+
BMM_TEMPLATE = r"""
|
| 39 |
+
{{ template.codegen_microkernel_def() }}
|
| 40 |
+
{{ template.codegen_single_thread_gemm() }}
|
| 41 |
+
{{ template.codegen_multi_thread_gemm() }}
|
| 42 |
+
|
| 43 |
+
extern "C"
|
| 44 |
+
{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}}
|
| 45 |
+
{
|
| 46 |
+
const int64_t B = {{kernel.size(BY_2d, 0)}};
|
| 47 |
+
{%- if num_threads > 1 %}
|
| 48 |
+
constexpr int64_t num_threads = {{num_threads}};
|
| 49 |
+
int64_t B_single_thread_block = (B / num_threads) * num_threads;
|
| 50 |
+
|
| 51 |
+
#pragma omp parallel for num_threads({{num_threads}})
|
| 52 |
+
{%- else %}
|
| 53 |
+
int64_t B_single_thread_block = B;
|
| 54 |
+
{%- endif %}
|
| 55 |
+
for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
|
| 56 |
+
{{template.get_gemm_function_call(
|
| 57 |
+
kernel,
|
| 58 |
+
kernel_name+"_single_thread_mm",
|
| 59 |
+
"<SINGLE_THREAD_CALL_FOR_BMM>",
|
| 60 |
+
b_index="b_start",
|
| 61 |
+
)}}
|
| 62 |
+
}
|
| 63 |
+
for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
|
| 64 |
+
{{template.get_gemm_function_call(
|
| 65 |
+
kernel,
|
| 66 |
+
kernel_name+"_threaded_mm",
|
| 67 |
+
"<THREADED_MM_CALL_FOR_BMM>",
|
| 68 |
+
b_index="b_start",
|
| 69 |
+
)}}
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class CppBmmTemplate(CppGemmTemplate):
|
| 76 |
+
def __init__(
    self,
    input_nodes,
    layout: ir.Layout,
    num_threads: int,
    register_blocking: GemmBlocking,
    beta=1,
    alpha=1,
    has_bias=False,
    epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    should_block_weights: bool = False,
    name="bmm",
):
    """
    In order to simplify the implementation and increase code reuse, the BMM template implements
    two versions of the GEMM kernel: a single-threaded version and a multi-threaded version.
    GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls
    for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel.

    We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM
    template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template
    without any changes to the GEMM template itself.
    """
    super().__init__(
        input_nodes,
        layout,
        num_threads,
        register_blocking,
        beta=beta,
        alpha=alpha,
        has_bias=has_bias,
        epilogue_creator=epilogue_creator,
        should_block_weights=should_block_weights,
        name=name,
    )
    # Symbolic batch index; replaced with the actual loop variable when the
    # GEMM sub-kernels are invoked from the batched loop (see
    # get_gemm_function_call).
    self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True)
@staticmethod
def get_padded_size(n, block_n, k, should_block_weight):
    """Return ``(new_size, padded_n)`` for the (possibly blocked) weight,
    with a leading -1 batch dimension prepended for BMM."""
    if should_block_weight:
        # Tensor is constant or not contiguous, so we will pad and block
        new_size, padded_n = CppGemmTemplate.get_padded_size(
            n, block_n, k, should_block_weight
        )
        # Add the new batch dimension
        new_size.insert(0, -1)
        return new_size, padded_n
    else:
        # No blocking needed: keep the plain (B, K, N) shape, N unpadded.
        new_size = [-1, k, n]
        return new_size, n
@staticmethod
|
| 128 |
+
def check_if_block_weight(W, micro_gemm):
|
| 129 |
+
assert isinstance(W, ir.IRNode)
|
| 130 |
+
_, n = W.get_size()[-2:]
|
| 131 |
+
result = (
|
| 132 |
+
not W.get_layout().is_contiguous()
|
| 133 |
+
or W.get_name() in V.graph.constants
|
| 134 |
+
or (
|
| 135 |
+
n % micro_gemm.register_blocking.block_n != 0
|
| 136 |
+
and micro_gemm.get_b_layout != LayoutType.NORMAL
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
return result
|
| 140 |
+
|
| 141 |
+
def get_gemm_function_call(
|
| 142 |
+
self,
|
| 143 |
+
kernel: CppTemplateKernel,
|
| 144 |
+
function_name: str,
|
| 145 |
+
placeholder: str,
|
| 146 |
+
b_index: str,
|
| 147 |
+
) -> str:
|
| 148 |
+
"""
|
| 149 |
+
Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
|
| 150 |
+
generate a function call for the GEMM kernel.
|
| 151 |
+
Args:
|
| 152 |
+
placeholder: The string to replace the function call with
|
| 153 |
+
b_index: The index for slicing the 3D batch tensors
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def hook():
|
| 157 |
+
arg_defs, call_args, _, _ = kernel.args.python_argdefs()
|
| 158 |
+
for i, buf in enumerate(call_args):
|
| 159 |
+
if buf == self.b_index:
|
| 160 |
+
arg_defs[i] = ArgName(b_index)
|
| 161 |
+
call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"
|
| 162 |
+
return call
|
| 163 |
+
|
| 164 |
+
assert placeholder not in kernel.render_hooks
|
| 165 |
+
kernel.render_hooks[placeholder] = hook
|
| 166 |
+
return placeholder
|
| 167 |
+
|
| 168 |
+
def get_default_reindexers(self, epilogue_nodes):
|
| 169 |
+
def reindexer(args):
|
| 170 |
+
# if epilogue nodes exist, they have 3D ranges but args are 2D, so add 0 index
|
| 171 |
+
return [self.b_index] + args
|
| 172 |
+
|
| 173 |
+
return [reindexer] * len(epilogue_nodes)
|
| 174 |
+
|
| 175 |
+
def get_options(
|
| 176 |
+
self,
|
| 177 |
+
kernel: CppTemplateKernel,
|
| 178 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 179 |
+
flag_template_buffer_has_other_users: Optional[bool] = None,
|
| 180 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 181 |
+
**kwargs,
|
| 182 |
+
) -> dict[str, Any]:
|
| 183 |
+
options = super().get_options(
|
| 184 |
+
kernel=kernel,
|
| 185 |
+
template_buffer_node=template_buffer_node,
|
| 186 |
+
flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
|
| 187 |
+
epilogue_nodes=epilogue_nodes,
|
| 188 |
+
**kwargs,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
BX, BW, BY = options["X"], options["W"], options["Y"]
|
| 192 |
+
options["BX"], options["BW"], options["BY"] = BX, BW, BY
|
| 193 |
+
options["BY_2d"] = options["Y_2d"]
|
| 194 |
+
for kword in ["X", "W", "GemmOut", "Y_2d"]:
|
| 195 |
+
options[kword] = kernel.select(options[kword], 0, self.b_index)
|
| 196 |
+
for kword in ["X", "W", "Y_2d"]:
|
| 197 |
+
options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype]
|
| 198 |
+
options["b_index"] = self.b_index
|
| 199 |
+
options["BY_sizevars"] = [
|
| 200 |
+
s
|
| 201 |
+
for sym in itertools.chain(BY.get_size(), BY.get_stride())
|
| 202 |
+
if isinstance(sym, sympy.Expr)
|
| 203 |
+
for s in sym.free_symbols
|
| 204 |
+
]
|
| 205 |
+
options["kernel_name"] = kernel.kernel_name
|
| 206 |
+
|
| 207 |
+
return options
|
| 208 |
+
|
| 209 |
+
def render( # type: ignore[override, return]
|
| 210 |
+
self,
|
| 211 |
+
kernel: CppTemplateKernel,
|
| 212 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 213 |
+
flag_template_buffer_has_other_users: Optional[bool] = None,
|
| 214 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 215 |
+
**kwargs,
|
| 216 |
+
) -> str:
|
| 217 |
+
options = self.get_options(
|
| 218 |
+
kernel=kernel,
|
| 219 |
+
template_buffer_node=template_buffer_node,
|
| 220 |
+
flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
|
| 221 |
+
epilogue_nodes=epilogue_nodes,
|
| 222 |
+
**kwargs,
|
| 223 |
+
)
|
| 224 |
+
self.render_options = options
|
| 225 |
+
|
| 226 |
+
with contextlib.ExitStack() as stack:
|
| 227 |
+
for buf in options["fake_buffers"]:
|
| 228 |
+
stack.enter_context(
|
| 229 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
|
| 230 |
+
)
|
| 231 |
+
result = self._template_from_string(BMM_TEMPLATE).render(**options)
|
| 232 |
+
|
| 233 |
+
# Finalize the function definitions for the gemm routines
|
| 234 |
+
sub_mm_hooks = {
|
| 235 |
+
name: hook
|
| 236 |
+
for name, hook in kernel.render_hooks.items()
|
| 237 |
+
if "FOR_BMM" in name
|
| 238 |
+
}
|
| 239 |
+
result = PartialRender(result, sub_mm_hooks).finalize_all()
|
| 240 |
+
for name in sub_mm_hooks:
|
| 241 |
+
del kernel.render_hooks[name]
|
| 242 |
+
del kernel.args.sizevars[options["b_index"]]
|
| 243 |
+
return result
|
| 244 |
+
|
| 245 |
+
def codegen_single_thread_gemm(self):
|
| 246 |
+
stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render(
|
| 247 |
+
self.render_options
|
| 248 |
+
)
|
| 249 |
+
return stub + self._template_from_string(GEMM_TEMPLATE).render(
|
| 250 |
+
{**self.render_options, "num_threads": 1}
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
def codegen_multi_thread_gemm(self):
|
| 254 |
+
stub = self._template_from_string(GEMM_THREADED_MM_STUB).render(
|
| 255 |
+
self.render_options
|
| 256 |
+
)
|
| 257 |
+
return stub + self._template_from_string(GEMM_TEMPLATE).render(
|
| 258 |
+
self.render_options
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def codegen_gemm_stub_def(self):
|
| 262 |
+
return ""
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py
ADDED
|
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from unittest.mock import patch
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.utils
|
| 12 |
+
|
| 13 |
+
from ...utils._ordered_set import OrderedSet
|
| 14 |
+
from .. import ir
|
| 15 |
+
from ..ir import TensorBox
|
| 16 |
+
from ..select_algorithm import DataProcessorTemplateWrapper
|
| 17 |
+
from ..utils import parallel_num_threads
|
| 18 |
+
from ..virtualized import V
|
| 19 |
+
from .cpp_template import CppTemplate
|
| 20 |
+
from .cpp_utils import GemmBlocking
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
log = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
# TODO: reuse cpp codegen to generate below pointwise/reduction kernels
|
| 26 |
+
SOFTMAX_FUSIONS = r"""
|
| 27 |
+
// 1) out = exp(a - val)
|
| 28 |
+
// 2) val = sum(out)
|
| 29 |
+
template <typename T1, typename T2>
|
| 30 |
+
inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel(
|
| 31 |
+
T1* a,
|
| 32 |
+
const int& size,
|
| 33 |
+
T2* out,
|
| 34 |
+
T1& val) {
|
| 35 |
+
auto vec_size = at::vec::Vectorized<T1>::size();
|
| 36 |
+
auto vec_max = at::vec::Vectorized<T1>(val);
|
| 37 |
+
T1 tmp_sum = 0;
|
| 38 |
+
auto vec_tmp_sum = at::vec::Vectorized<T1>(tmp_sum);
|
| 39 |
+
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
|
| 40 |
+
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
|
| 41 |
+
auto tmp1 = tmp0 - vec_max;
|
| 42 |
+
auto tmp2 = tmp1.exp_u20();
|
| 43 |
+
vec_tmp_sum += tmp2;
|
| 44 |
+
at::native::_store(out + i, tmp2);
|
| 45 |
+
}
|
| 46 |
+
tmp_sum = at::vec::vec_reduce_all<T1>(
|
| 47 |
+
[](at::vec::Vectorized<T1>& x, at::vec::Vectorized<T1>& y) {
|
| 48 |
+
return x + y;
|
| 49 |
+
},
|
| 50 |
+
vec_tmp_sum);
|
| 51 |
+
for (long i = vec_size * (size / vec_size); i < size; i++) {
|
| 52 |
+
auto tmp0 = a[i];
|
| 53 |
+
auto tmp1 = tmp0 - val;
|
| 54 |
+
auto tmp2 = exp(tmp1);
|
| 55 |
+
tmp_sum += tmp2;
|
| 56 |
+
out[i] = tmp2;
|
| 57 |
+
}
|
| 58 |
+
val = tmp_sum;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// 1) out = a * scale
|
| 62 |
+
// 2) max = max(out)
|
| 63 |
+
template <typename scalar_t>
|
| 64 |
+
inline void {{kernel_name}}_mul_reduce_max_fusion_kernel(
|
| 65 |
+
const scalar_t* a,
|
| 66 |
+
const scalar_t& scale,
|
| 67 |
+
const int& size,
|
| 68 |
+
scalar_t* out,
|
| 69 |
+
scalar_t& max) {
|
| 70 |
+
auto vec_size = at::vec::Vectorized<scalar_t>::size();
|
| 71 |
+
auto vec_scale = at::vec::Vectorized<scalar_t>(scale);
|
| 72 |
+
scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
|
| 73 |
+
auto vec_tmp_max = at::vec::Vectorized<scalar_t>(tmp_max);
|
| 74 |
+
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
|
| 75 |
+
auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
|
| 76 |
+
auto tmp1 = tmp0 * vec_scale;
|
| 77 |
+
vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1);
|
| 78 |
+
at::native::_store(out + i, tmp1);
|
| 79 |
+
}
|
| 80 |
+
for (long i = vec_size * (size / vec_size); i < size; i++) {
|
| 81 |
+
auto tmp0 = a[i];
|
| 82 |
+
auto tmp1 = tmp0 * scale;
|
| 83 |
+
tmp_max = std::max(tmp_max, tmp1);
|
| 84 |
+
out[i] = tmp1;
|
| 85 |
+
}
|
| 86 |
+
max = std::max(
|
| 87 |
+
tmp_max,
|
| 88 |
+
at::vec::vec_reduce_all<scalar_t>(
|
| 89 |
+
[](at::vec::Vectorized<scalar_t>& x, at::vec::Vectorized<scalar_t>& y) {
|
| 90 |
+
return at::vec::maximum(x, y);
|
| 91 |
+
},
|
| 92 |
+
vec_tmp_max));
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
template <typename scalar_t>
|
| 96 |
+
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
|
| 97 |
+
TORCH_CHECK(ptr2 == nullptr);
|
| 98 |
+
return ptr;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
template <typename scalar_t,
|
| 102 |
+
typename std::enable_if_t<c10::is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 103 |
+
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) {
|
| 104 |
+
return ptr2;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
template <typename scalar_t>
|
| 108 |
+
inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) {
|
| 109 |
+
using Vec = at::vec::Vectorized<scalar_t>;
|
| 110 |
+
Vec data_vec = Vec(val);
|
| 111 |
+
int64_t d = 0;
|
| 112 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 113 |
+
data_vec.store(data + d);
|
| 114 |
+
}
|
| 115 |
+
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
|
| 116 |
+
# pragma unroll
|
| 117 |
+
#endif
|
| 118 |
+
for (; d < size; d++) {
|
| 119 |
+
data[d] = val;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// out = a * scale
|
| 124 |
+
template <typename scalar_t>
|
| 125 |
+
inline void {{kernel_name}}_mul_scale_kernel(
|
| 126 |
+
scalar_t* a,
|
| 127 |
+
scalar_t scale,
|
| 128 |
+
int64_t size) {
|
| 129 |
+
auto vec_size = at::vec::Vectorized<scalar_t>::size();
|
| 130 |
+
auto vec_scale = at::vec::Vectorized<scalar_t>(scale);
|
| 131 |
+
for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) {
|
| 132 |
+
auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
|
| 133 |
+
auto tmp1 = tmp0 * vec_scale;
|
| 134 |
+
at::native::_store(a + i, tmp1);
|
| 135 |
+
}
|
| 136 |
+
for (int64_t i = vec_size * (size / vec_size); i < size; i++) {
|
| 137 |
+
auto tmp0 = a[i];
|
| 138 |
+
auto tmp1 = tmp0 * scale;
|
| 139 |
+
a[i] = tmp1;
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
BRGEMM_PACK_FUNCTIONS = r"""
|
| 146 |
+
template <typename scalar_t>
|
| 147 |
+
inline void {{kernel_name}}_copy_value_with_pad(
|
| 148 |
+
const scalar_t* value_ptr,
|
| 149 |
+
scalar_t* dst_ptr,
|
| 150 |
+
int64_t rows,
|
| 151 |
+
int64_t cols,
|
| 152 |
+
int64_t prows,
|
| 153 |
+
int64_t pcols,
|
| 154 |
+
int64_t ldi) {
|
| 155 |
+
auto vec_size = at::vec::Vectorized<scalar_t>::size();
|
| 156 |
+
int64_t i = 0;
|
| 157 |
+
for (; i < rows; i++) {
|
| 158 |
+
int64_t j = 0;
|
| 159 |
+
for (; j < cols - (cols % vec_size); j += vec_size) {
|
| 160 |
+
auto vec_v =
|
| 161 |
+
at::vec::Vectorized<scalar_t>::loadu(value_ptr + i * ldi + j);
|
| 162 |
+
vec_v.store(dst_ptr + i * pcols + j);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
if (j < cols) {
|
| 166 |
+
auto vec_v = at::vec::Vectorized<scalar_t>::loadu(
|
| 167 |
+
value_ptr + i * ldi + j, cols - j);
|
| 168 |
+
vec_v.store(dst_ptr + i * pcols + j, cols - j);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// col padding
|
| 172 |
+
auto psize = pcols - cols;
|
| 173 |
+
if (psize > 0) {
|
| 174 |
+
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
|
| 175 |
+
int64_t pj = 0;
|
| 176 |
+
for (; pj < psize - (psize % vec_size); pj += vec_size) {
|
| 177 |
+
zero_vec.store(dst_ptr + i * pcols + cols + pj);
|
| 178 |
+
}
|
| 179 |
+
if (pj < psize) {
|
| 180 |
+
zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
// row padding
|
| 185 |
+
for (; i < prows; i++) {
|
| 186 |
+
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
|
| 187 |
+
int64_t j = 0;
|
| 188 |
+
for (; j < pcols - (pcols % vec_size); j += vec_size) {
|
| 189 |
+
zero_vec.store(dst_ptr + i * pcols + j);
|
| 190 |
+
}
|
| 191 |
+
if (j < pcols) {
|
| 192 |
+
zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
MICRO_GEMM_TEMPLATE = r"""
|
| 200 |
+
GEMM_DEFINE
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
ALLOCATE_BUFFER = r"""
|
| 204 |
+
int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4;
|
| 205 |
+
auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
|
| 206 |
+
auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize);
|
| 207 |
+
void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get();
|
| 208 |
+
{{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr;
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
FLEX_ATTENTION_TEMPLATE = r"""
|
| 212 |
+
{{template.header().getvalue()}}
|
| 213 |
+
#include <ATen/native/cpu/utils.h>
|
| 214 |
+
#include <ATen/native/CPUBlas.h>
|
| 215 |
+
#include <ATen/Context.h>
|
| 216 |
+
{{template.codegen_micro_gemm(kernel.kernel_name)}}
|
| 217 |
+
{{template.codegen_softmax_fusion(kernel.kernel_name)}}
|
| 218 |
+
{{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
|
| 219 |
+
{%- set kernel_args = {"query": query, "key": key, "value": value,
|
| 220 |
+
"kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices,
|
| 221 |
+
"full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %}
|
| 222 |
+
{%- set kernel_args = template.update_kernel_args(kernel_args) %}
|
| 223 |
+
|
| 224 |
+
extern "C"
|
| 225 |
+
{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}}
|
| 226 |
+
{
|
| 227 |
+
{{ kernel.maybe_codegen_profile() }}
|
| 228 |
+
int64_t qBlockSize = {{qBlockSize}};
|
| 229 |
+
int64_t kvBlockSize = {{kvBlockSize}};
|
| 230 |
+
int64_t num_thread = {{num_thread}};
|
| 231 |
+
|
| 232 |
+
// dtypes of kernel and internal buffers
|
| 233 |
+
using scalar_t = {{kernel.dtype(query)}};
|
| 234 |
+
constexpr bool is_reduced_type = c10::is_reduced_floating_point_v<scalar_t>;
|
| 235 |
+
using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
|
| 236 |
+
using Vec = at::vec::Vectorized<accum_t>;
|
| 237 |
+
accum_t scaling_factor = {{scale}};
|
| 238 |
+
int64_t batchSize = {{kernel.size(query, 0)}};
|
| 239 |
+
int64_t qSize = {{kernel.size(query, 1)}};
|
| 240 |
+
int64_t num_head = {{kernel.size(query, 2)}};
|
| 241 |
+
int64_t headSize = {{kernel.size(query, 3)}};
|
| 242 |
+
int64_t batchSize_k = {{kernel.size(key, 0)}};
|
| 243 |
+
int64_t num_head_k = {{kernel.size(key, 2)}};
|
| 244 |
+
int64_t headSize_v = {{kernel.size(value, 3)}};
|
| 245 |
+
bool is_broadcast_bs_kv = batchSize != batchSize_k;
|
| 246 |
+
bool is_broadcast_head_kv = num_head != num_head_k;
|
| 247 |
+
int64_t gqa_shards = num_head / num_head_k;
|
| 248 |
+
int64_t bs_shards = batchSize / batchSize_k;
|
| 249 |
+
|
| 250 |
+
int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}};
|
| 251 |
+
int64_t num_head_kvi = {{kernel.size(kv_indices, 1)}};
|
| 252 |
+
int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}};
|
| 253 |
+
bool is_broadcast_bs_kvi = batchSize != batchSize_kvi;
|
| 254 |
+
bool is_broadcast_head_kvi = num_head != num_head_kvi;
|
| 255 |
+
int64_t gqa_shards_kvi = num_head / num_head_kvi;
|
| 256 |
+
int64_t bs_shards_kvi = batchSize / batchSize_kvi;
|
| 257 |
+
|
| 258 |
+
int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}};
|
| 259 |
+
int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}};
|
| 260 |
+
int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}};
|
| 261 |
+
|
| 262 |
+
int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}};
|
| 263 |
+
int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}};
|
| 264 |
+
|
| 265 |
+
{%- if has_full_kv_block %}
|
| 266 |
+
int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}};
|
| 267 |
+
int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}};
|
| 268 |
+
int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}};
|
| 269 |
+
|
| 270 |
+
int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}};
|
| 271 |
+
int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}};
|
| 272 |
+
auto full_kv_indices_data = full_kv_indices;
|
| 273 |
+
auto full_kv_num_blocks_data = full_kv_num_blocks;
|
| 274 |
+
{%- endif %}
|
| 275 |
+
|
| 276 |
+
auto kv_num_blocks_data = kv_num_blocks;
|
| 277 |
+
auto kv_indices_data = kv_indices;
|
| 278 |
+
|
| 279 |
+
// Strides
|
| 280 |
+
int64_t qStrideB = {{kernel.stride(query, 0)}};
|
| 281 |
+
int64_t qStrideM = {{kernel.stride(query, 1)}};
|
| 282 |
+
int64_t qStrideH = {{kernel.stride(query, 2)}};
|
| 283 |
+
int64_t kStrideB = {{kernel.stride(key, 0)}};
|
| 284 |
+
int64_t kStrideN = {{kernel.stride(key, 1)}};
|
| 285 |
+
int64_t kStrideH = {{kernel.stride(key, 2)}};
|
| 286 |
+
int64_t vStrideB = {{kernel.stride(value, 0)}};
|
| 287 |
+
int64_t vStrideN = {{kernel.stride(value, 1)}};
|
| 288 |
+
int64_t vStrideH = {{kernel.stride(value, 2)}};
|
| 289 |
+
int64_t oStrideB = {{kernel.stride(output, 0)}};
|
| 290 |
+
int64_t oStrideM = {{kernel.stride(output, 2)}};
|
| 291 |
+
int64_t oStrideH = {{kernel.stride(output, 1)}};
|
| 292 |
+
|
| 293 |
+
int64_t kvSize = {{kernel.size(key, 1)}};
|
| 294 |
+
|
| 295 |
+
int64_t qSplitSize = qBlockSize;
|
| 296 |
+
int64_t kvSplitSize = kvBlockSize;
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
|
| 300 |
+
kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
|
| 301 |
+
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
|
| 302 |
+
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
|
| 303 |
+
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
|
| 304 |
+
|
| 305 |
+
bool need_pack = false;
|
| 306 |
+
// Whether pack is needed for BFloat16/Half
|
| 307 |
+
if (is_reduced_type) {
|
| 308 |
+
// check platform ability
|
| 309 |
+
need_pack = std::is_same_v<scalar_t, at::BFloat16> ? at::native::cpublas::could_pack(at::kBFloat16)
|
| 310 |
+
: at::native::cpublas::could_pack(at::kHalf);
|
| 311 |
+
}
|
| 312 |
+
if (need_pack) {
|
| 313 |
+
// When the number of gemm is greater than the number of pack,
|
| 314 |
+
// the pack overhead can be overlapped.
|
| 315 |
+
int64_t thresh_size = 64;
|
| 316 |
+
need_pack = kvSize >= thresh_size && qSize >= thresh_size;
|
| 317 |
+
if (need_pack) {
|
| 318 |
+
double pack_size = batchSize * num_head * kvSize * headSize;
|
| 319 |
+
double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread;
|
| 320 |
+
double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize;
|
| 321 |
+
need_pack = gemm_size_per_thread / pack_size >= 4;
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
// Pad is needed for packing when K is not even
|
| 325 |
+
bool headSize_even = headSize % 2 == 0;
|
| 326 |
+
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
|
| 327 |
+
int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
|
| 328 |
+
int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
|
| 329 |
+
int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
|
| 330 |
+
|
| 331 |
+
// Allocate per thread temp buf (accumulate type)
|
| 332 |
+
int64_t _size_per_thread =
|
| 333 |
+
/* qk */ qSplitSize * kvSplitSize +
|
| 334 |
+
/* qk_max */ qSplitSize +
|
| 335 |
+
/* qk_sum */ qSplitSize +
|
| 336 |
+
/* dst */ qSplitSize * headSize_v;
|
| 337 |
+
|
| 338 |
+
// Inputs/outputs buffers
|
| 339 |
+
const scalar_t* q_data = query;
|
| 340 |
+
const scalar_t* k_data = key;
|
| 341 |
+
const scalar_t* v_data = value;
|
| 342 |
+
scalar_t* out_data = output;
|
| 343 |
+
|
| 344 |
+
// Buffers to store accum results, padding query and transpose/packing key/value
|
| 345 |
+
{{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}}
|
| 346 |
+
{{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}}
|
| 347 |
+
{{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}}
|
| 348 |
+
{{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}}
|
| 349 |
+
{{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
|
| 350 |
+
{{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}
|
| 351 |
+
if (need_pack) {
|
| 352 |
+
// Pack K, V
|
| 353 |
+
at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) {
|
| 354 |
+
int ompIdx = at::get_thread_num();
|
| 355 |
+
int64_t i = 0, j = 0, l = 0, n = 0;
|
| 356 |
+
scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
|
| 357 |
+
at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice);
|
| 358 |
+
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
| 359 |
+
n = l * kvSplitSize;
|
| 360 |
+
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
|
| 361 |
+
auto k_addr =
|
| 362 |
+
k_data + i * kStrideB + j * kStrideH + n * kStrideN;
|
| 363 |
+
auto v_addr =
|
| 364 |
+
v_data + i * vStrideB + j * vStrideH + n * vStrideN;
|
| 365 |
+
// transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
|
| 366 |
+
at::native::utils::transpose<uint16_t>(
|
| 367 |
+
cur_kvSplitSize,
|
| 368 |
+
headSize,
|
| 369 |
+
/* src_ptr */
|
| 370 |
+
reinterpret_cast<const uint16_t*>(k_addr),
|
| 371 |
+
/* ld_src */ kStrideN,
|
| 372 |
+
/* dst */ reinterpret_cast<uint16_t*>(transpose_ptr),
|
| 373 |
+
/* ld_dst */ cur_kvSplitSize);
|
| 374 |
+
|
| 375 |
+
// Pack [headSize, cur_kvSplitSize]
|
| 376 |
+
at::vec::pack_vnni2(
|
| 377 |
+
/* src */ reinterpret_cast<const uint16_t*>(transpose_ptr),
|
| 378 |
+
/* dst */ reinterpret_cast<uint16_t*>(key_reorder_ptr + i * num_head_k * eheadSize * kvSize +
|
| 379 |
+
j * eheadSize * kvSize + n * eheadSize),
|
| 380 |
+
/* ld_src */ cur_kvSplitSize,
|
| 381 |
+
/* K */ headSize,
|
| 382 |
+
/* N */ cur_kvSplitSize);
|
| 383 |
+
|
| 384 |
+
// Pack [cur_kvSplitSize, headSize_v]
|
| 385 |
+
at::vec::pack_vnni2(
|
| 386 |
+
/* src */ reinterpret_cast<const uint16_t*>(v_addr),
|
| 387 |
+
/* dst */ reinterpret_cast<uint16_t*>(value_reorder_ptr +
|
| 388 |
+
i * num_head_k * kv_padding_size * headSize_v +
|
| 389 |
+
j * kv_padding_size * headSize_v + n * headSize_v),
|
| 390 |
+
/* ld_src */ vStrideN,
|
| 391 |
+
/* K */ cur_kvSplitSize,
|
| 392 |
+
/* N */ headSize_v);
|
| 393 |
+
// Move to the next query
|
| 394 |
+
at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice);
|
| 395 |
+
}
|
| 396 |
+
});
|
| 397 |
+
}
|
| 398 |
+
// Attention loop below
|
| 399 |
+
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
|
| 400 |
+
int64_t i = 0, j = 0, k = 0;
|
| 401 |
+
at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
|
| 402 |
+
int ompIdx = at::get_thread_num();
|
| 403 |
+
accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread;
|
| 404 |
+
accum_t* qk_data = buf_ptr;
|
| 405 |
+
accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
|
| 406 |
+
accum_t* qk_sum_data = qk_max_data + qSplitSize;
|
| 407 |
+
accum_t* dst_data = qk_sum_data + qSplitSize;
|
| 408 |
+
scalar_t *qk_reduced_data =
|
| 409 |
+
is_reduced_type
|
| 410 |
+
? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize
|
| 411 |
+
: nullptr;
|
| 412 |
+
scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
|
| 413 |
+
? query_padding_ptr + ompIdx * qSplitSize * eheadSize
|
| 414 |
+
: nullptr;
|
| 415 |
+
|
| 416 |
+
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
| 417 |
+
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
|
| 418 |
+
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
|
| 419 |
+
auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB +
|
| 420 |
+
j_kvi * num_kviStrideH + k;
|
| 421 |
+
int kv_indice_num = *kv_logical_num_data;
|
| 422 |
+
std::vector<int> kv_indice_list(kv_indice_num);
|
| 423 |
+
for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){
|
| 424 |
+
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
|
| 425 |
+
j_kvi * kviStrideH + k*kviStrideQ + kv_i;
|
| 426 |
+
kv_indice_list[kv_i] = *kv_logical_data;
|
| 427 |
+
}
|
| 428 |
+
bool is_skip_kv = kv_indice_num > 0 ? false : true;
|
| 429 |
+
{%- if has_full_kv_block %}
|
| 430 |
+
auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB +
|
| 431 |
+
j_kvi * num_kviStrideH + k;
|
| 432 |
+
int full_kv_indice_num = *full_kv_logical_num_data;
|
| 433 |
+
std::vector<int> full_kv_indice_list(full_kv_indice_num);
|
| 434 |
+
for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){
|
| 435 |
+
auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB +
|
| 436 |
+
j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i;
|
| 437 |
+
full_kv_indice_list[kv_i] = *full_kv_logical_data;
|
| 438 |
+
}
|
| 439 |
+
is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true;
|
| 440 |
+
{%- endif %}
|
| 441 |
+
int64_t m = k * qSplitSize;
|
| 442 |
+
int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m);
|
| 443 |
+
if (!is_skip_kv){
|
| 444 |
+
// Initialize max and sum
|
| 445 |
+
{{kernel.kernel_name}}_fill_stub(qk_max_data,
|
| 446 |
+
-std::numeric_limits<accum_t>::infinity(), cur_qSplitSize);
|
| 447 |
+
{{kernel.kernel_name}}_fill_stub(qk_sum_data,
|
| 448 |
+
static_cast<accum_t>(0), cur_qSplitSize);
|
| 449 |
+
|
| 450 |
+
if (!headSize_even && need_pack) {
|
| 451 |
+
// Pad query if headSize is not even
|
| 452 |
+
{{kernel.kernel_name}}_copy_value_with_pad<scalar_t>(
|
| 453 |
+
q_data + i * qStrideB + j * qStrideH + m * qStrideM,
|
| 454 |
+
query_t_padding_ptr,
|
| 455 |
+
cur_qSplitSize,
|
| 456 |
+
headSize,
|
| 457 |
+
cur_qSplitSize,
|
| 458 |
+
eheadSize,
|
| 459 |
+
qStrideM
|
| 460 |
+
);
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
{%- if has_full_kv_block %}
|
| 465 |
+
for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) {
|
| 466 |
+
auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
|
| 467 |
+
{%- else %}
|
| 468 |
+
for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) {
|
| 469 |
+
auto n = kv_indice_list[n_idx]*kvSplitSize;
|
| 470 |
+
{%- endif %}
|
| 471 |
+
|
| 472 |
+
auto cur_n = n/kvSplitSize;
|
| 473 |
+
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
|
| 474 |
+
int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize;
|
| 475 |
+
|
| 476 |
+
// Calculate scale * q @ k.T
|
| 477 |
+
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
|
| 478 |
+
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
|
| 479 |
+
|
| 480 |
+
if (!need_pack) {
|
| 481 |
+
auto k_addr =
|
| 482 |
+
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
|
| 483 |
+
|
| 484 |
+
{{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<static_cast<bool>(false)>(
|
| 485 |
+
q_data + i * qStrideB + j * qStrideH +
|
| 486 |
+
m * qStrideM,
|
| 487 |
+
k_addr,
|
| 488 |
+
qk_data,
|
| 489 |
+
cur_qSplitSize,
|
| 490 |
+
cur_kvSplitSize,
|
| 491 |
+
headSize,
|
| 492 |
+
qStrideM,
|
| 493 |
+
kStrideN,
|
| 494 |
+
cur_kvSplitSize);
|
| 495 |
+
|
| 496 |
+
} else {
|
| 497 |
+
at::native::cpublas::brgemm(
|
| 498 |
+
cur_qSplitSize,
|
| 499 |
+
cur_kvSplitSize,
|
| 500 |
+
eheadSize,
|
| 501 |
+
headSize_even ? qStrideM : eheadSize,
|
| 502 |
+
cur_kvSplitSize,
|
| 503 |
+
cur_kvSplitSize,
|
| 504 |
+
false,
|
| 505 |
+
!headSize_even
|
| 506 |
+
? query_t_padding_ptr
|
| 507 |
+
: q_data + i * qStrideB + j * qStrideH + m * qStrideM,
|
| 508 |
+
key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize +
|
| 509 |
+
j_kv * eheadSize * kvSize + n * eheadSize,
|
| 510 |
+
qk_data,
|
| 511 |
+
need_pack);
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
{{kernel.kernel_name}}_mul_scale_kernel<accum_t>(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
|
| 515 |
+
|
| 516 |
+
{%- if score_mod and mask_mod %}
|
| 517 |
+
// TODO: reduce the number of calls of q_idx and kv_idx initialization
|
| 518 |
+
std::vector<int64_t> q_idx(cur_qSplitSize);
|
| 519 |
+
for (int64_t i = 0; i < cur_qSplitSize; ++i) {
|
| 520 |
+
q_idx[i] = m + i;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
std::vector<int64_t> kv_idx(cur_kvSplitSize);
|
| 524 |
+
for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
|
| 525 |
+
kv_idx[i] = n + i;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
std::vector<int64_t> b_idx = {i};
|
| 529 |
+
std::vector<int64_t> h_idx = {j};
|
| 530 |
+
|
| 531 |
+
accum_t* in_ptr0 = qk_data;
|
| 532 |
+
|
| 533 |
+
auto in_ptr1 = b_idx.data();
|
| 534 |
+
auto in_ptr2 = h_idx.data();
|
| 535 |
+
auto in_ptr3 = q_idx.data();
|
| 536 |
+
auto in_ptr4 = kv_idx.data();
|
| 537 |
+
|
| 538 |
+
// apply score mod function
|
| 539 |
+
{
|
| 540 |
+
{{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
|
| 541 |
+
accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
|
| 542 |
+
{{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){
|
| 546 |
+
// Apply block mask, fill unused with -inf
|
| 547 |
+
{
|
| 548 |
+
{{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
|
| 549 |
+
accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
|
| 550 |
+
{{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
|
| 551 |
+
}
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
{%- endif %}
|
| 555 |
+
// Update coefficients with Softmax
|
| 556 |
+
accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
|
| 557 |
+
for (int64_t row = 0; row < cur_qSplitSize; ++row) {
|
| 558 |
+
// apply scaling factor and max per row in fusion
|
| 559 |
+
{{kernel.kernel_name}}_mul_reduce_max_fusion_kernel(
|
| 560 |
+
qk_data + row * cur_kvSplitSize,
|
| 561 |
+
static_cast<accum_t>(1),
|
| 562 |
+
cur_kvSplitSize,
|
| 563 |
+
qk_data + row * cur_kvSplitSize,
|
| 564 |
+
tmp_max);
|
| 565 |
+
tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
|
| 566 |
+
if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
|
| 567 |
+
// to avoid `nan = exp2f(-inf - (-inf))`
|
| 568 |
+
{{kernel.kernel_name}}_fill_stub(
|
| 569 |
+
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
|
| 570 |
+
static_cast<scalar_t>(0), cur_kvSplitSize);
|
| 571 |
+
} else {
|
| 572 |
+
tmp_sum = tmp_max;
|
| 573 |
+
// qk <- exp(qk - max) and sum per row
|
| 574 |
+
{{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel(
|
| 575 |
+
qk_data + row * cur_kvSplitSize, cur_kvSplitSize,
|
| 576 |
+
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
|
| 577 |
+
tmp_sum);
|
| 578 |
+
// exp_tmp <- exp(max[row] - max)
|
| 579 |
+
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
|
| 580 |
+
// sum[row] <- sum + exp_tmp * sum[row]
|
| 581 |
+
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
|
| 582 |
+
// max[row] <- max
|
| 583 |
+
qk_max_data[row] = tmp_max;
|
| 584 |
+
// dst <- dst * exp_tmp
|
| 585 |
+
if (n_idx > 0) {
|
| 586 |
+
at::vec::map<accum_t>(
|
| 587 |
+
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
|
| 588 |
+
dst_data + row * headSize_v,
|
| 589 |
+
dst_data + row * headSize_v,
|
| 590 |
+
headSize_v);
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
if (need_pack && cur_kvSplitSize % 2 != 0) {
|
| 594 |
+
// Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1]
|
| 595 |
+
*(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0);
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
// Calculate Softmax(q @ k.T) @ v
|
| 599 |
+
if (!need_pack) {
|
| 600 |
+
auto v_addr =
|
| 601 |
+
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
|
| 602 |
+
// Fallback Half brgemm is slower than micro gemm
|
| 603 |
+
if (!std::is_same_v<scalar_t, at::Half>) {
|
| 604 |
+
at::native::cpublas::brgemm(
|
| 605 |
+
cur_qSplitSize,
|
| 606 |
+
headSize_v,
|
| 607 |
+
cur_ekvSplitSize,
|
| 608 |
+
cur_ekvSplitSize,
|
| 609 |
+
vStrideN,
|
| 610 |
+
headSize_v,
|
| 611 |
+
n_idx > 0,
|
| 612 |
+
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
|
| 613 |
+
v_addr,
|
| 614 |
+
dst_data,
|
| 615 |
+
need_pack);
|
| 616 |
+
} else {
|
| 617 |
+
if (n_idx > 0) {
|
| 618 |
+
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
|
| 619 |
+
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
|
| 620 |
+
v_addr,
|
| 621 |
+
dst_data,
|
| 622 |
+
cur_qSplitSize,
|
| 623 |
+
headSize_v,
|
| 624 |
+
cur_ekvSplitSize,
|
| 625 |
+
cur_ekvSplitSize,
|
| 626 |
+
vStrideN,
|
| 627 |
+
headSize_v);
|
| 628 |
+
} else {
|
| 629 |
+
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
|
| 630 |
+
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
|
| 631 |
+
v_addr,
|
| 632 |
+
dst_data,
|
| 633 |
+
cur_qSplitSize,
|
| 634 |
+
headSize_v,
|
| 635 |
+
cur_ekvSplitSize,
|
| 636 |
+
cur_ekvSplitSize,
|
| 637 |
+
vStrideN,
|
| 638 |
+
headSize_v);
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
} else {
|
| 642 |
+
int64_t psize = n / kvSplitSize * ekvSplitSize;
|
| 643 |
+
at::native::cpublas::brgemm(
|
| 644 |
+
cur_qSplitSize,
|
| 645 |
+
headSize_v,
|
| 646 |
+
cur_ekvSplitSize,
|
| 647 |
+
cur_ekvSplitSize,
|
| 648 |
+
headSize_v,
|
| 649 |
+
headSize_v,
|
| 650 |
+
n_idx > 0,
|
| 651 |
+
qk_reduced_data,
|
| 652 |
+
value_reorder_ptr +
|
| 653 |
+
i_kv * num_head_k * kv_padding_size * headSize_v +
|
| 654 |
+
j_kv * kv_padding_size * headSize_v + psize * headSize_v,
|
| 655 |
+
dst_data,
|
| 656 |
+
need_pack);
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
// dst <- dst / sum[row]
|
| 661 |
+
// reorder MHA output with strides
|
| 662 |
+
for (int64_t row = 0; row < cur_qSplitSize; ++row) {
|
| 663 |
+
// Row sums for full masked out rows are 0, we set them to 1
|
| 664 |
+
// in order to avoid NaNs in the output and instead set fully
|
| 665 |
+
// masked out rows to 0
|
| 666 |
+
qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
|
| 667 |
+
qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
|
| 668 |
+
accum_t sum_reciprocal = 1 / qk_sum_data[row];
|
| 669 |
+
at::vec::map<scalar_t>(
|
| 670 |
+
[sum_reciprocal, is_skip_kv](Vec x) { return is_skip_kv ? Vec(0.0) : x * Vec(sum_reciprocal); },
|
| 671 |
+
out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
|
| 672 |
+
dst_data + row * headSize_v,
|
| 673 |
+
headSize_v);
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
// Move to the next query
|
| 677 |
+
at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
at::native::cpublas::brgemm_release(need_pack);
|
| 681 |
+
|
| 682 |
+
});
|
| 683 |
+
}
|
| 684 |
+
"""
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class CppFlexAttentionTemplate(CppTemplate):
|
| 688 |
+
def __init__(
|
| 689 |
+
self,
|
| 690 |
+
input_nodes,
|
| 691 |
+
layout: ir.Layout,
|
| 692 |
+
scale,
|
| 693 |
+
score_mod,
|
| 694 |
+
mask_mod,
|
| 695 |
+
kv_block_size,
|
| 696 |
+
q_block_size,
|
| 697 |
+
has_other_buffer,
|
| 698 |
+
no_full_kv_block,
|
| 699 |
+
fake_buffers,
|
| 700 |
+
len_score_other,
|
| 701 |
+
len_mask_other,
|
| 702 |
+
kernel_input_name_to_buffer,
|
| 703 |
+
block_vars,
|
| 704 |
+
) -> None:
|
| 705 |
+
assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
|
| 706 |
+
super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
|
| 707 |
+
self.scale = scale
|
| 708 |
+
self.score_mod = score_mod
|
| 709 |
+
self.mask_mod = mask_mod
|
| 710 |
+
self.score_buf_name = (
|
| 711 |
+
V.graph.register_buffer(self.score_mod) if self.score_mod else None
|
| 712 |
+
)
|
| 713 |
+
self.mask_buf_name = (
|
| 714 |
+
V.graph.register_buffer(self.mask_mod) if self.mask_mod else None
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
def get_idx(buf_name):
|
| 718 |
+
match = re.search(r"\d+", buf_name)
|
| 719 |
+
assert match, f"incorrect score buf name: {buf_name}"
|
| 720 |
+
return match.group()
|
| 721 |
+
|
| 722 |
+
self.score_buf_idx = (
|
| 723 |
+
get_idx(self.score_buf_name) if self.score_buf_name else None
|
| 724 |
+
)
|
| 725 |
+
self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None
|
| 726 |
+
self.kv_block_size = kv_block_size
|
| 727 |
+
self.q_block_size = q_block_size
|
| 728 |
+
self.has_other_buffer = has_other_buffer
|
| 729 |
+
self.no_full_kv_block = no_full_kv_block
|
| 730 |
+
self.other_buffer_input_offset = 2
|
| 731 |
+
if self.no_full_kv_block:
|
| 732 |
+
self.other_buffer_input_offset = 0
|
| 733 |
+
self.fake_buffers = fake_buffers
|
| 734 |
+
self.len_score_other = len_score_other
|
| 735 |
+
self.len_mask_other = len_mask_other
|
| 736 |
+
self.kernel_input_name_to_buffer = kernel_input_name_to_buffer
|
| 737 |
+
self.block_vars = block_vars
|
| 738 |
+
self.extra_sizevars = list(
|
| 739 |
+
OrderedSet(
|
| 740 |
+
val
|
| 741 |
+
for val in self.kernel_input_name_to_buffer.values()
|
| 742 |
+
if isinstance(val, sympy.Symbol)
|
| 743 |
+
)
|
| 744 |
+
)
|
| 745 |
+
self.other_buf_start_idx = 5
|
| 746 |
+
self.score_mod_other_buffers = (
|
| 747 |
+
self.input_nodes[
|
| 748 |
+
self.other_buf_start_idx
|
| 749 |
+
+ self.other_buffer_input_offset : self.other_buf_start_idx
|
| 750 |
+
+ self.other_buffer_input_offset
|
| 751 |
+
+ self.len_score_other
|
| 752 |
+
]
|
| 753 |
+
if self.has_other_buffer
|
| 754 |
+
else None
|
| 755 |
+
)
|
| 756 |
+
self.mask_mod_other_buffers = (
|
| 757 |
+
self.input_nodes[
|
| 758 |
+
self.other_buf_start_idx
|
| 759 |
+
+ self.other_buffer_input_offset
|
| 760 |
+
+ self.len_score_other :
|
| 761 |
+
]
|
| 762 |
+
if self.has_other_buffer
|
| 763 |
+
else None
|
| 764 |
+
)
|
| 765 |
+
self.other_ptr_data = {} # type: ignore[var-annotated]
|
| 766 |
+
|
| 767 |
+
def update_kernel_args(self, kernel_args):
|
| 768 |
+
kernel_args.update(
|
| 769 |
+
{
|
| 770 |
+
key: value
|
| 771 |
+
for key, value in self.kernel_input_name_to_buffer.items()
|
| 772 |
+
if not isinstance(value, sympy.Symbol)
|
| 773 |
+
}
|
| 774 |
+
)
|
| 775 |
+
return kernel_args
|
| 776 |
+
|
| 777 |
+
def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args):
|
| 778 |
+
kernel_input_name_to_buffer_name = {
|
| 779 |
+
key: value if isinstance(value, sympy.Symbol) else value.get_name()
|
| 780 |
+
for key, value in self.kernel_input_name_to_buffer.items()
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
def get_arg(name):
|
| 784 |
+
return kernel_input_name_to_buffer_name.get(name)
|
| 785 |
+
|
| 786 |
+
def get_arg_name(name):
|
| 787 |
+
if isinstance(get_arg(name), sympy.Symbol):
|
| 788 |
+
return kernel_args.sizevars.get(get_arg(name))
|
| 789 |
+
return kernel_args.input_buffers.get(get_arg(name))
|
| 790 |
+
|
| 791 |
+
if not self.has_other_buffer:
|
| 792 |
+
return ""
|
| 793 |
+
|
| 794 |
+
if start_offset == -1:
|
| 795 |
+
start_offset = getattr(self, len_attr)
|
| 796 |
+
|
| 797 |
+
length = getattr(self, len_attr)
|
| 798 |
+
for i in range(length):
|
| 799 |
+
pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}"
|
| 800 |
+
buffer_key = f"{buf_list}_{i}"
|
| 801 |
+
if pointer not in self.other_ptr_data:
|
| 802 |
+
self.other_ptr_data[pointer] = (
|
| 803 |
+
get_arg_name(buffer_key),
|
| 804 |
+
get_arg(buffer_key),
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
return "\n".join(
|
| 808 |
+
f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items()
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
def modification(self, subgraph_buffer, output_name, output_idx):
|
| 812 |
+
assert isinstance(subgraph_buffer, ir.ComputedBuffer)
|
| 813 |
+
subgraph_buffer_data = subgraph_buffer.data
|
| 814 |
+
from ..loop_body import LoopBody
|
| 815 |
+
from ..utils import sympy_index_symbol_with_prefix, SymT
|
| 816 |
+
from ..virtualized import V
|
| 817 |
+
from .cpp import CppKernelProxy, KernelGroup
|
| 818 |
+
|
| 819 |
+
kernel_group = KernelGroup()
|
| 820 |
+
kernel_input_args = {
|
| 821 |
+
"score": "in_ptr0",
|
| 822 |
+
"b": "in_ptr1",
|
| 823 |
+
"h": "in_ptr2",
|
| 824 |
+
"q_idx": "in_ptr3",
|
| 825 |
+
"kv_idx": "in_ptr4",
|
| 826 |
+
}
|
| 827 |
+
if self.has_other_buffer:
|
| 828 |
+
kernel_input_args.update(
|
| 829 |
+
{arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()}
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
kernel_output_args = {output_name: f"out_ptr{output_idx}"}
|
| 833 |
+
|
| 834 |
+
args = kernel_group.args
|
| 835 |
+
for name, inp in kernel_input_args.items():
|
| 836 |
+
args.input_buffers[name] = inp
|
| 837 |
+
|
| 838 |
+
for name, inp in kernel_output_args.items():
|
| 839 |
+
args.output_buffers[name] = inp
|
| 840 |
+
|
| 841 |
+
for name in self.extra_sizevars:
|
| 842 |
+
args.sizevars[name] = f"k{name}"
|
| 843 |
+
|
| 844 |
+
kernel_group.args = args
|
| 845 |
+
|
| 846 |
+
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
| 847 |
+
bodies = []
|
| 848 |
+
var_sizes_list = []
|
| 849 |
+
var_sizes = tuple(subgraph_buffer.get_size())
|
| 850 |
+
var_ranges = {
|
| 851 |
+
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
|
| 852 |
+
for i, sz in enumerate(var_sizes)
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
dst_layout = subgraph_buffer.get_layout()
|
| 856 |
+
output_index = dst_layout.make_indexer()([*var_ranges.keys()])
|
| 857 |
+
|
| 858 |
+
def fn(*args):
|
| 859 |
+
V.ops.store(
|
| 860 |
+
output_name,
|
| 861 |
+
output_index,
|
| 862 |
+
subgraph_buffer_data.make_loader()(args).value,
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
body = LoopBody(
|
| 866 |
+
fn,
|
| 867 |
+
(list(var_ranges.keys())),
|
| 868 |
+
var_ranges,
|
| 869 |
+
list(var_ranges.keys()),
|
| 870 |
+
tuple(),
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
from ..loop_body import MemoryUsageType
|
| 874 |
+
|
| 875 |
+
assert all(
|
| 876 |
+
mem.buffer_name in kernel_group.args.input_buffers
|
| 877 |
+
for mem in body.memory_usage[MemoryUsageType.LOAD]
|
| 878 |
+
), (
|
| 879 |
+
"All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
bodies.append(body)
|
| 883 |
+
var_sizes_list.append((var_sizes, ()))
|
| 884 |
+
|
| 885 |
+
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
|
| 886 |
+
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
| 887 |
+
output_code = kernel_group.loops_code.getvalue()
|
| 888 |
+
|
| 889 |
+
var_q_symbol, var_kv_symbol = self.block_vars
|
| 890 |
+
# See [Note] Handle the case where the split sizes are not statically known.
|
| 891 |
+
# We don't know the value of qBlockSize and rkvBlockSize during compilation time
|
| 892 |
+
# thus we've represented them by symbols.
|
| 893 |
+
# We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize"
|
| 894 |
+
# in the generated code thus they'll be filled with the real value during runtime.
|
| 895 |
+
if var_q_symbol in kernel_group.args.sizevars:
|
| 896 |
+
output_code = output_code.replace(
|
| 897 |
+
kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize"
|
| 898 |
+
)
|
| 899 |
+
if var_kv_symbol in kernel_group.args.sizevars:
|
| 900 |
+
output_code = output_code.replace(
|
| 901 |
+
kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize"
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
return output_code
|
| 905 |
+
|
| 906 |
+
@staticmethod
|
| 907 |
+
def add_choices(
|
| 908 |
+
choices,
|
| 909 |
+
input_nodes,
|
| 910 |
+
layout,
|
| 911 |
+
scale,
|
| 912 |
+
score_mod,
|
| 913 |
+
mask_mod,
|
| 914 |
+
kv_block_size,
|
| 915 |
+
q_block_size,
|
| 916 |
+
has_other_buffer,
|
| 917 |
+
no_full_kv_block,
|
| 918 |
+
fake_buffers,
|
| 919 |
+
len_score_other,
|
| 920 |
+
len_mask_other,
|
| 921 |
+
kernel_input_name_to_buffer,
|
| 922 |
+
block_vars,
|
| 923 |
+
):
|
| 924 |
+
def preprocessor(input_nodes, layout):
|
| 925 |
+
return input_nodes, layout
|
| 926 |
+
|
| 927 |
+
def postprocessor(output):
|
| 928 |
+
return output
|
| 929 |
+
|
| 930 |
+
template = DataProcessorTemplateWrapper(
|
| 931 |
+
CppFlexAttentionTemplate,
|
| 932 |
+
preprocessor,
|
| 933 |
+
postprocessor,
|
| 934 |
+
input_nodes=input_nodes,
|
| 935 |
+
layout=layout,
|
| 936 |
+
scale=scale,
|
| 937 |
+
score_mod=score_mod,
|
| 938 |
+
mask_mod=mask_mod,
|
| 939 |
+
kv_block_size=kv_block_size,
|
| 940 |
+
q_block_size=q_block_size,
|
| 941 |
+
has_other_buffer=has_other_buffer,
|
| 942 |
+
no_full_kv_block=no_full_kv_block,
|
| 943 |
+
fake_buffers=fake_buffers,
|
| 944 |
+
len_score_other=len_score_other,
|
| 945 |
+
len_mask_other=len_mask_other,
|
| 946 |
+
kernel_input_name_to_buffer=kernel_input_name_to_buffer,
|
| 947 |
+
block_vars=block_vars,
|
| 948 |
+
)
|
| 949 |
+
template.maybe_append_choice(choices)
|
| 950 |
+
return template
|
| 951 |
+
|
| 952 |
+
def apply_score_mod(self, score, b, h, q_idx, kv_idx):
|
| 953 |
+
return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
|
| 954 |
+
|
| 955 |
+
def render( # type: ignore[override,return]
|
| 956 |
+
self,
|
| 957 |
+
kernel,
|
| 958 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 959 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 960 |
+
**kwargs,
|
| 961 |
+
) -> str:
|
| 962 |
+
if epilogue_nodes is not None and epilogue_nodes != []:
|
| 963 |
+
raise NotImplementedError(
|
| 964 |
+
"Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate."
|
| 965 |
+
)
|
| 966 |
+
# Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
|
| 967 |
+
# -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
|
| 968 |
+
# Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
|
| 969 |
+
# -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
|
| 970 |
+
# Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
|
| 971 |
+
# -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
|
| 972 |
+
|
| 973 |
+
query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
|
| 974 |
+
key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
|
| 975 |
+
value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
|
| 976 |
+
self.accumulate_dtype = torch.float
|
| 977 |
+
self.input_dtype = query.layout.dtype
|
| 978 |
+
|
| 979 |
+
num_threads = parallel_num_threads()
|
| 980 |
+
buf_out = TensorBox.create(self.output_node)
|
| 981 |
+
if template_buffer_node is not None:
|
| 982 |
+
buf_out = template_buffer_node
|
| 983 |
+
options = dict(
|
| 984 |
+
query=query,
|
| 985 |
+
key=key,
|
| 986 |
+
value=value,
|
| 987 |
+
kv_num_blocks=self.input_nodes[3],
|
| 988 |
+
kv_indices=self.input_nodes[4],
|
| 989 |
+
full_kv_num_blocks=self.input_nodes[5]
|
| 990 |
+
if not self.no_full_kv_block
|
| 991 |
+
else None,
|
| 992 |
+
full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None,
|
| 993 |
+
score_mod_other_buffers=self.score_mod_other_buffers,
|
| 994 |
+
mask_mod_other_buffers=self.mask_mod_other_buffers,
|
| 995 |
+
scale=self.scale,
|
| 996 |
+
has_full_kv_block=not self.no_full_kv_block,
|
| 997 |
+
accumulate_dtype=self.accumulate_dtype,
|
| 998 |
+
query_dtype=self.input_dtype,
|
| 999 |
+
kvBlockSize=self.kv_block_size,
|
| 1000 |
+
qBlockSize=self.q_block_size,
|
| 1001 |
+
template=self,
|
| 1002 |
+
output=buf_out,
|
| 1003 |
+
kernel=kernel,
|
| 1004 |
+
num_thread=num_threads,
|
| 1005 |
+
score_mod=self.score_mod,
|
| 1006 |
+
mask_mod=self.mask_mod,
|
| 1007 |
+
score_buf_name=self.score_buf_name,
|
| 1008 |
+
mask_buf_name=self.mask_buf_name,
|
| 1009 |
+
score_buf_idx=self.score_buf_idx,
|
| 1010 |
+
mask_buf_idx=self.mask_buf_idx,
|
| 1011 |
+
)
|
| 1012 |
+
with contextlib.ExitStack() as stack:
|
| 1013 |
+
for buf in self.fake_buffers:
|
| 1014 |
+
stack.enter_context(
|
| 1015 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
|
| 1016 |
+
)
|
| 1017 |
+
return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
|
| 1018 |
+
|
| 1019 |
+
def codegen_softmax_fusion(self, kernel_name: str):
|
| 1020 |
+
# TODO: use inductor IR to rewrite those fusions
|
| 1021 |
+
return self._template_from_string(SOFTMAX_FUSIONS).render(
|
| 1022 |
+
dict(kernel_name=kernel_name)
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
def codegen_brgemm_pack_function(self, kernel_name: str):
|
| 1026 |
+
# TODO: make them general for common bmm templates
|
| 1027 |
+
return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render(
|
| 1028 |
+
dict(kernel_name=kernel_name)
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size):
|
| 1032 |
+
return self._template_from_string(ALLOCATE_BUFFER).render(
|
| 1033 |
+
dict(
|
| 1034 |
+
buffer_name=buffer_name,
|
| 1035 |
+
buffer_dtype=buffer_dtype,
|
| 1036 |
+
buffer_size=buffer_size,
|
| 1037 |
+
)
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
def micro_gemm_define(self, kernel_name: str):
|
| 1041 |
+
from torch._inductor.codegen.cpp_gemm_template import (
|
| 1042 |
+
CppTemplateKernel,
|
| 1043 |
+
parallel_num_threads,
|
| 1044 |
+
)
|
| 1045 |
+
from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
|
| 1046 |
+
from torch._inductor.virtualized import V
|
| 1047 |
+
|
| 1048 |
+
micro_gemm_trans = CppMicroGemmFP32Vec(
|
| 1049 |
+
kernel_name + "_kernel_micro_gemm_transpose_b",
|
| 1050 |
+
self.input_dtype,
|
| 1051 |
+
self.input_dtype,
|
| 1052 |
+
self.accumulate_dtype,
|
| 1053 |
+
self.accumulate_dtype,
|
| 1054 |
+
GemmBlocking(1, 16, 1),
|
| 1055 |
+
1,
|
| 1056 |
+
True,
|
| 1057 |
+
True,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
micro_gemm = CppMicroGemmFP32Vec(
|
| 1061 |
+
kernel_name + "_kernel_micro_gemm",
|
| 1062 |
+
self.input_dtype,
|
| 1063 |
+
self.input_dtype,
|
| 1064 |
+
self.accumulate_dtype,
|
| 1065 |
+
self.accumulate_dtype,
|
| 1066 |
+
GemmBlocking(1, 16, 1),
|
| 1067 |
+
1,
|
| 1068 |
+
True,
|
| 1069 |
+
False,
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
with V.set_graph_handler(V.graph):
|
| 1073 |
+
kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
|
| 1074 |
+
code_trans = micro_gemm_trans.codegen_define(kernel)
|
| 1075 |
+
code = micro_gemm.codegen_define(kernel)
|
| 1076 |
+
return code + code_trans
|
| 1077 |
+
|
| 1078 |
+
def codegen_micro_gemm(self, kernel_name: str):
|
| 1079 |
+
micro_gemm = self.micro_gemm_define(kernel_name)
|
| 1080 |
+
GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
|
| 1081 |
+
return self._template_from_string(GEMM_SOURCE_CODE).render()
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py
ADDED
|
@@ -0,0 +1,1777 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
| 7 |
+
from unittest.mock import patch
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils
|
| 11 |
+
from torch.utils._ordered_set import OrderedSet
|
| 12 |
+
|
| 13 |
+
from ..._dynamo.utils import counters
|
| 14 |
+
from .. import config, ir, lowering as L
|
| 15 |
+
from ..kernel.mm_common import mm_args
|
| 16 |
+
from ..select_algorithm import DataProcessorTemplateWrapper
|
| 17 |
+
from ..utils import (
|
| 18 |
+
has_free_symbols,
|
| 19 |
+
is_same_mkldnn_tensor,
|
| 20 |
+
is_same_tensor,
|
| 21 |
+
parallel_num_threads,
|
| 22 |
+
)
|
| 23 |
+
from ..virtualized import ops, V
|
| 24 |
+
from .cpp import get_export_declaration
|
| 25 |
+
from .cpp_micro_gemm import (
|
| 26 |
+
CppMicroBrgemm,
|
| 27 |
+
CppMicroGemm,
|
| 28 |
+
CppMicroGemmAMX,
|
| 29 |
+
CppMicroGemmFP32Vec,
|
| 30 |
+
create_micro_gemm,
|
| 31 |
+
is_int8_woq_gemm_small_m_dim_corner_case,
|
| 32 |
+
LayoutType,
|
| 33 |
+
)
|
| 34 |
+
from .cpp_template import CppTemplate
|
| 35 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 36 |
+
from .cpp_utils import (
|
| 37 |
+
create_epilogue_with_attr,
|
| 38 |
+
DTYPE_TO_CPP,
|
| 39 |
+
GemmBlocking,
|
| 40 |
+
get_gemm_template_output_and_compute_dtype,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
log = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r"""
|
| 47 |
+
constexpr int64_t num_threads = {{num_threads}};
|
| 48 |
+
constexpr int64_t N = {{N}};
|
| 49 |
+
constexpr int64_t K = {{K}};
|
| 50 |
+
constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
|
| 51 |
+
constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
|
| 52 |
+
constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
|
| 53 |
+
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
|
| 54 |
+
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
|
| 55 |
+
{%- if is_dynamic_M %}
|
| 56 |
+
const int64_t M = {{kernel.size(GemmOut, 0)}};
|
| 57 |
+
const int64_t Mr_blocks = (M + Mr - 1) / Mr;
|
| 58 |
+
{%- else %}
|
| 59 |
+
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
|
| 60 |
+
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
|
| 61 |
+
{%- endif %}
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r"""
|
| 65 |
+
{%- if is_dynamic_M %}
|
| 66 |
+
{%- if num_threads > 1 %}
|
| 67 |
+
int64_t Mt_blocks, Nt_blocks, Kt_blocks;
|
| 68 |
+
mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
|
| 69 |
+
{%- else %}
|
| 70 |
+
const auto Mt_blocks = Mr_blocks;
|
| 71 |
+
const auto Nt_blocks = Nr_blocks;
|
| 72 |
+
const auto Kt_blocks = Kr_blocks;
|
| 73 |
+
{%- endif %}
|
| 74 |
+
int64_t Mc_blocks, Nc_blocks, Kc_blocks;
|
| 75 |
+
uint32_t L1_cache_size = {{L1_cache_size}};
|
| 76 |
+
uint32_t L2_cache_size = {{L2_cache_size}};
|
| 77 |
+
mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
|
| 78 |
+
num_threads,
|
| 79 |
+
M,
|
| 80 |
+
N,
|
| 81 |
+
K,
|
| 82 |
+
Mr,
|
| 83 |
+
Nr,
|
| 84 |
+
Kr,
|
| 85 |
+
Mt_blocks,
|
| 86 |
+
Nt_blocks,
|
| 87 |
+
Kt_blocks,
|
| 88 |
+
Mc_blocks,
|
| 89 |
+
Nc_blocks,
|
| 90 |
+
Kc_blocks,
|
| 91 |
+
L1_cache_size,
|
| 92 |
+
L2_cache_size
|
| 93 |
+
);
|
| 94 |
+
const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
|
| 95 |
+
const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
|
| 96 |
+
const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
|
| 97 |
+
const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
|
| 98 |
+
const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
|
| 99 |
+
{%- else %}
|
| 100 |
+
constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}};
|
| 101 |
+
constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}};
|
| 102 |
+
constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}};
|
| 103 |
+
constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}};
|
| 104 |
+
constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}};
|
| 105 |
+
constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}};
|
| 106 |
+
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
|
| 107 |
+
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
|
| 108 |
+
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
|
| 109 |
+
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
|
| 110 |
+
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
|
| 111 |
+
{%- endif %}
|
| 112 |
+
{%- if is_woq_int4 %}
|
| 113 |
+
int64_t group_size = *q_group_size;
|
| 114 |
+
{%- endif %}
|
| 115 |
+
|
| 116 |
+
// make sure all partitions are assigned
|
| 117 |
+
{{kernel.assert_function}}(
|
| 118 |
+
Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
|
| 119 |
+
"Not all partitions are assigned."
|
| 120 |
+
);
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r"""
|
| 124 |
+
const int tid = omp_get_thread_num();
|
| 125 |
+
const int64_t k_group_id = tid / num_Kt_blocks;
|
| 126 |
+
const int64_t k_slice_id = tid % num_Kt_blocks;
|
| 127 |
+
const int64_t n_group_id = k_group_id / num_Nt_blocks;
|
| 128 |
+
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
|
| 129 |
+
const int64_t k_block_start = k_slice_id * Kt_blocks;
|
| 130 |
+
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
|
| 131 |
+
const int64_t n_block_start = n_slice_id * Nt_blocks;
|
| 132 |
+
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
|
| 133 |
+
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
|
| 134 |
+
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
|
| 135 |
+
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r"""
|
| 139 |
+
constexpr int tid = 0;
|
| 140 |
+
constexpr int64_t k_group_id = 0;
|
| 141 |
+
constexpr int64_t k_slice_id = 0;
|
| 142 |
+
constexpr int64_t n_group_id = 0;
|
| 143 |
+
constexpr int64_t n_slice_id = 0;
|
| 144 |
+
constexpr int64_t m_block_start = 0;
|
| 145 |
+
constexpr int64_t n_block_start = 0;
|
| 146 |
+
constexpr int64_t n_block_end = Nr_blocks;
|
| 147 |
+
constexpr int64_t k_block_start = 0;
|
| 148 |
+
constexpr int64_t k_block_end = Kr_blocks;
|
| 149 |
+
{%- if is_dynamic_M %}
|
| 150 |
+
const int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
|
| 151 |
+
const int64_t m_block_end = Mr_blocks;
|
| 152 |
+
{%- else %}
|
| 153 |
+
constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
|
| 154 |
+
constexpr int64_t m_block_end = Mr_blocks;
|
| 155 |
+
{%- endif %}
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
GEMM_TEMPLATE_M_LOOP_PARAMS = r"""
|
| 159 |
+
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
|
| 160 |
+
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
|
| 161 |
+
const int64_t m_start = mc * Mr;
|
| 162 |
+
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
|
| 163 |
+
const int64_t m_size = m_end - m_start;
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
GEMM_TEMPLATE_N_LOOP_PARAMS = r"""
|
| 167 |
+
const int64_t n_start = nc * Nr;
|
| 168 |
+
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
|
| 169 |
+
const int64_t n_size = n_end - n_start;
|
| 170 |
+
// NB: assume we pad N, nc_block_end won't exceed padded N here.
|
| 171 |
+
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
GEMM_TEMPLATE_MICROKERNEL_DEF = r"""
|
| 175 |
+
{{template.header().getvalue()}}
|
| 176 |
+
|
| 177 |
+
{{micro_gemm.codegen_define(kernel)}}
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
GEMM_TEMPLATE_STUB_DEF = r"""
|
| 181 |
+
{%- if x_scale is not none %}
|
| 182 |
+
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
|
| 183 |
+
{%- elif is_woq_int4 %}
|
| 184 |
+
{%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %}
|
| 185 |
+
{%- else %}
|
| 186 |
+
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
|
| 187 |
+
{%- endif %}
|
| 188 |
+
|
| 189 |
+
extern "C" {{export_declaration}}
|
| 190 |
+
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
GEMM_TEMPLATE = r"""
|
| 194 |
+
{{ template.codegen_gemm_stub_def() }}
|
| 195 |
+
{
|
| 196 |
+
{{ kernel.maybe_codegen_profile() }}
|
| 197 |
+
{{ template.codegen_blocks(
|
| 198 |
+
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
|
| 199 |
+
) }}
|
| 200 |
+
|
| 201 |
+
{%- if maybe_k_slicing %}
|
| 202 |
+
std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
|
| 203 |
+
if (num_Kt_blocks > 1) {
|
| 204 |
+
local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]);
|
| 205 |
+
}
|
| 206 |
+
{%- endif %}
|
| 207 |
+
|
| 208 |
+
{%- if num_threads > 1 %}
|
| 209 |
+
#pragma omp parallel num_threads({{num_threads}})
|
| 210 |
+
{
|
| 211 |
+
{{ template.codegen_multi_threads_params()|indent(8, false) }}
|
| 212 |
+
{%- else %}
|
| 213 |
+
{
|
| 214 |
+
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
|
| 215 |
+
{%- endif %}
|
| 216 |
+
{{ micro_gemm.codegen_init(kernel) }}
|
| 217 |
+
{%- if use_local_acc %}
|
| 218 |
+
{%- set acc_buf_name = "local_acc_buf" %}
|
| 219 |
+
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
|
| 220 |
+
{%- endif %}
|
| 221 |
+
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
|
| 222 |
+
{{ template.codegen_m_loop_params()|indent(12, false) }}
|
| 223 |
+
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
| 224 |
+
{{ template.codegen_n_loop_params()|indent(16, false) }}
|
| 225 |
+
{%- if use_local_acc %}
|
| 226 |
+
{%- set acc = kernel.local_buffers[acc_buf_name] %}
|
| 227 |
+
{{ kernel.reinit_buffer_if_null(acc_buf_name) }}
|
| 228 |
+
{%- else %}
|
| 229 |
+
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
|
| 230 |
+
{%- endif %}
|
| 231 |
+
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
|
| 232 |
+
int64_t k_start = kc * Kr;
|
| 233 |
+
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
|
| 234 |
+
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
|
| 235 |
+
for (int64_t nci = nc; nci < nc_block_end; nci++) {
|
| 236 |
+
{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
|
| 237 |
+
{%- if template.should_block_weights and not is_woq_int4 %}
|
| 238 |
+
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
|
| 239 |
+
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
|
| 240 |
+
{%- else %}
|
| 241 |
+
{%- if is_woq_int4 %}
|
| 242 |
+
{%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}
|
| 243 |
+
{%- set tile_qparam = kernel.slice_nd(
|
| 244 |
+
qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %}
|
| 245 |
+
{%- else %}
|
| 246 |
+
{%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
|
| 247 |
+
{%- set tile_qparam = None %}
|
| 248 |
+
{%- endif %}
|
| 249 |
+
{%- endif %}
|
| 250 |
+
if (kc == k_block_start) {
|
| 251 |
+
{{ micro_gemm.codegen_call(kernel,
|
| 252 |
+
tile_X,
|
| 253 |
+
tile_W,
|
| 254 |
+
acc_slice,
|
| 255 |
+
accum=False,
|
| 256 |
+
qscale_and_zeros=tile_qparam)|indent(28, false)
|
| 257 |
+
}}
|
| 258 |
+
} else {
|
| 259 |
+
{{ micro_gemm.codegen_call(kernel,
|
| 260 |
+
tile_X,
|
| 261 |
+
tile_W,
|
| 262 |
+
acc_slice,
|
| 263 |
+
accum=True,
|
| 264 |
+
qscale_and_zeros=tile_qparam)|indent(28, false)
|
| 265 |
+
}}
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
{%- if maybe_k_slicing %}
|
| 270 |
+
if (num_Kt_blocks > 1) {
|
| 271 |
+
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
|
| 272 |
+
local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset(
|
| 273 |
+
{{ kernel.release_buffer(acc_buf_name) }});
|
| 274 |
+
} else
|
| 275 |
+
{%- endif %}
|
| 276 |
+
{
|
| 277 |
+
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
|
| 278 |
+
{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
|
| 279 |
+
{{ kernel.store_output(
|
| 280 |
+
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
|
| 281 |
+
)|indent(20, false)
|
| 282 |
+
}}
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
{%- if maybe_k_slicing %}
|
| 287 |
+
if (num_Kt_blocks > 1) {
|
| 288 |
+
#pragma omp barrier
|
| 289 |
+
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
|
| 290 |
+
// We slice M-dim and each thread in the k-slicing group works on a slice
|
| 291 |
+
const int64_t m_start_unsliced = mc * Mr;
|
| 292 |
+
const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
|
| 293 |
+
const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
|
| 294 |
+
const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks;
|
| 295 |
+
const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
|
| 296 |
+
const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
|
| 297 |
+
const int64_t m_size = m_end - m_start;
|
| 298 |
+
const int64_t m_offset = m_start - m_start_unsliced;
|
| 299 |
+
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
| 300 |
+
const int64_t n_start = nc * Nr;
|
| 301 |
+
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
|
| 302 |
+
const int64_t n_size = n_end - n_start;
|
| 303 |
+
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
|
| 304 |
+
auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get();
|
| 305 |
+
for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) {
|
| 306 |
+
auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get();
|
| 307 |
+
for (int64_t m = m_offset; m < m_offset + m_size; m++) {
|
| 308 |
+
#pragma omp simd
|
| 309 |
+
for (int64_t n = 0; n < n_size; n++) {
|
| 310 |
+
{{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
{%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
|
| 315 |
+
{{ kernel.store_output(
|
| 316 |
+
tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
|
| 317 |
+
)|indent(20, false)
|
| 318 |
+
}}
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
{%- endif %}
|
| 323 |
+
{{ micro_gemm.codegen_finalize(kernel) }}
|
| 324 |
+
}
|
| 325 |
+
}
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
SMALL_M_GEMM_TEMPLATE = r"""
|
| 329 |
+
{{ template.codegen_gemm_stub_def() }}
|
| 330 |
+
{
|
| 331 |
+
{{ kernel.maybe_codegen_profile() }}
|
| 332 |
+
{{ template.codegen_blocks(
|
| 333 |
+
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
|
| 334 |
+
) }}
|
| 335 |
+
# pragma omp parallel
|
| 336 |
+
{
|
| 337 |
+
#pragma omp for nowait
|
| 338 |
+
for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) {
|
| 339 |
+
// Handle one output M * Nr block in each thread
|
| 340 |
+
int64_t n_start = nr_block_id * Nr;
|
| 341 |
+
int64_t n_end = (nr_block_id + 1) * Nr;
|
| 342 |
+
{%- if use_local_acc %}
|
| 343 |
+
{%- set acc_buf_name = "local_acc_buf" %}
|
| 344 |
+
{{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }}
|
| 345 |
+
{%- set acc = kernel.local_buffers[acc_buf_name] %}
|
| 346 |
+
{%- else %}
|
| 347 |
+
{%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %}
|
| 348 |
+
{%- endif %}
|
| 349 |
+
for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) {
|
| 350 |
+
// this loop is not parallelized
|
| 351 |
+
int64_t k_start = kr_block_id * Kr;
|
| 352 |
+
int64_t k_end = std::min((kr_block_id + 1) * Kr, K);
|
| 353 |
+
{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %}
|
| 354 |
+
{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %}
|
| 355 |
+
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
|
| 356 |
+
if C10_UNLIKELY(kr_block_id == 0) {
|
| 357 |
+
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }}
|
| 358 |
+
} else if C10_UNLIKELY(k_end == K) {
|
| 359 |
+
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }}
|
| 360 |
+
} else {
|
| 361 |
+
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }}
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %}
|
| 365 |
+
{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %}
|
| 366 |
+
{{ kernel.store_output(
|
| 367 |
+
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers
|
| 368 |
+
)|indent(20, false) }}
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _is_int8_gemm(inputs):
|
| 376 |
+
return (
|
| 377 |
+
isinstance(inputs[0], ir.IRNode)
|
| 378 |
+
and inputs[0].get_dtype() in [torch.uint8, torch.int8]
|
| 379 |
+
) or (
|
| 380 |
+
isinstance(inputs[0], torch.Tensor)
|
| 381 |
+
and inputs[0].dtype in [torch.uint8, torch.int8]
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def get_padded_n(n, block_n):
|
| 386 |
+
return (n + block_n - 1) // block_n * block_n
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
_T = TypeVar("_T", ir.IRNode, torch.Tensor)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def transpose_w(W: _T, trans_w: bool) -> _T:
|
| 393 |
+
"""
|
| 394 |
+
Transpose W based on the trans_w flag.
|
| 395 |
+
"""
|
| 396 |
+
if isinstance(W, ir.IRNode):
|
| 397 |
+
if trans_w:
|
| 398 |
+
if not isinstance(W, ir.TensorBox):
|
| 399 |
+
W = ir.TensorBox(W)
|
| 400 |
+
W = L.permute(W, [1, 0])
|
| 401 |
+
else:
|
| 402 |
+
if trans_w:
|
| 403 |
+
assert isinstance(W, torch.Tensor)
|
| 404 |
+
W = W.transpose(0, 1)
|
| 405 |
+
return W
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
|
| 409 |
+
"""
|
| 410 |
+
Expand Bias to the same size of X.
|
| 411 |
+
"""
|
| 412 |
+
if B is not None:
|
| 413 |
+
if isinstance(B, ir.IRNode):
|
| 414 |
+
if not isinstance(B, ir.TensorBox):
|
| 415 |
+
B = ir.TensorBox(B)
|
| 416 |
+
assert hasattr(X, "get_size")
|
| 417 |
+
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
|
| 418 |
+
else:
|
| 419 |
+
assert isinstance(B, torch.Tensor)
|
| 420 |
+
assert isinstance(X, torch.Tensor)
|
| 421 |
+
B = B.expand(X.shape[0], B.shape[-1])
|
| 422 |
+
return B
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]):
|
| 426 |
+
"""
|
| 427 |
+
Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor):
|
| 431 |
+
return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and (
|
| 432 |
+
is_same_tensor(base_tensor, comp_tensor)
|
| 433 |
+
or is_same_mkldnn_tensor(base_tensor, comp_tensor)
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
def get_candidates(input_nodes, new_input_nodes):
|
| 437 |
+
# Only Constant Buffer like weight and bias might be changed in GEMM Template.
|
| 438 |
+
# The Inductor IR Node may changed, but still share the storage. For example:
|
| 439 |
+
# bias in bfloat16 case which only do the expand
|
| 440 |
+
return [
|
| 441 |
+
node
|
| 442 |
+
for node in input_nodes
|
| 443 |
+
if (
|
| 444 |
+
node not in new_input_nodes
|
| 445 |
+
and isinstance(node, (ir.TensorBox, ir.StorageBox))
|
| 446 |
+
and node.get_name() in V.graph.constants
|
| 447 |
+
and not any(
|
| 448 |
+
(
|
| 449 |
+
isinstance(new_node, (ir.TensorBox, ir.StorageBox))
|
| 450 |
+
and new_node.get_name() in V.graph.constants
|
| 451 |
+
and share_storage(
|
| 452 |
+
V.graph.constants[node.get_name()],
|
| 453 |
+
V.graph.constants[new_node.get_name()],
|
| 454 |
+
)
|
| 455 |
+
)
|
| 456 |
+
for new_node in new_input_nodes
|
| 457 |
+
)
|
| 458 |
+
)
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
for candidate_node in get_candidates(input_nodes, new_input_nodes):
|
| 462 |
+
# By using the new packed weight for the GEMM template, we can prune the
|
| 463 |
+
# old weight if it has no other users. This saves memory but makes the FX graph
|
| 464 |
+
# non-retraceable. To support retracing, we can add a repack node to the
|
| 465 |
+
# FX graph. For example:
|
| 466 |
+
# mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
|
| 467 |
+
candidate_tensor_users = 0
|
| 468 |
+
candidate_tensor = V.graph.constants[candidate_node.get_name()]
|
| 469 |
+
for node in reversed(V.graph.graph.nodes):
|
| 470 |
+
# Case may happen when the candidate tensor is used by more than 1 get_attr node
|
| 471 |
+
# https://github.com/pytorch/pytorch/issues/134998
|
| 472 |
+
if node.op == "get_attr" and hasattr(
|
| 473 |
+
V.graph.module, node.target
|
| 474 |
+
): # candidate tensor might already be deleted
|
| 475 |
+
comp_tensor = getattr(V.graph.module, node.target)
|
| 476 |
+
if isinstance(comp_tensor, torch.Tensor) and share_storage(
|
| 477 |
+
candidate_tensor, comp_tensor
|
| 478 |
+
):
|
| 479 |
+
candidate_tensor_users += 1
|
| 480 |
+
|
| 481 |
+
for node in reversed(V.graph.graph.nodes):
|
| 482 |
+
# The get_attr node has only 1 user fx node
|
| 483 |
+
# The candidate tensor has been used by only 1 get_attr node
|
| 484 |
+
if (
|
| 485 |
+
node.op == "get_attr"
|
| 486 |
+
and node.target == candidate_node.get_name()
|
| 487 |
+
and len(node.users) == 1
|
| 488 |
+
and candidate_tensor_users == 1
|
| 489 |
+
):
|
| 490 |
+
del V.graph.constants[node.target]
|
| 491 |
+
delattr(V.graph.module, node.target)
|
| 492 |
+
delattr(V.graph.graph.owning_module, node.target)
|
| 493 |
+
counters["inductor"]["select_algorithm_weight_prune"] += 1
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def gen_2d_view_of_epilogue_buf(
|
| 497 |
+
Y: ir.Buffer,
|
| 498 |
+
template_buffer: ir.Buffer,
|
| 499 |
+
epilogue_nodes: list[ir.IRNode],
|
| 500 |
+
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
|
| 501 |
+
default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
|
| 502 |
+
) -> tuple[
|
| 503 |
+
Union[ir.Buffer, ir.ReinterpretView],
|
| 504 |
+
list[Optional[Callable[[list[Any]], list[Any]]]],
|
| 505 |
+
]:
|
| 506 |
+
"""
|
| 507 |
+
The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is
|
| 508 |
+
2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code,
|
| 509 |
+
we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to
|
| 510 |
+
the indexing of the GEMM output.
|
| 511 |
+
In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and
|
| 512 |
+
build a reindexer that converts the indexing of `Y` into `Y_2d`.
|
| 513 |
+
"""
|
| 514 |
+
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
|
| 515 |
+
if (
|
| 516 |
+
Y.get_size() == template_buffer.get_size()
|
| 517 |
+
and Y.get_stride() == template_buffer.get_stride()
|
| 518 |
+
):
|
| 519 |
+
reindexers.extend(default_reindexers)
|
| 520 |
+
Y_2d = Y
|
| 521 |
+
else:
|
| 522 |
+
|
| 523 |
+
def get_reindexer(epilogue_node, default_reindexer=None):
|
| 524 |
+
# From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
|
| 525 |
+
# template_buffer:
|
| 526 |
+
# size (324, 512), stride (512, 1)
|
| 527 |
+
# epilogue_node_ordered (ordered by stride decreasingly, in dense format):
|
| 528 |
+
# size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
|
| 529 |
+
stride_order = list(
|
| 530 |
+
ir.get_stride_order(
|
| 531 |
+
V.graph.sizevars.size_hints(epilogue_node.get_stride())
|
| 532 |
+
)
|
| 533 |
+
)
|
| 534 |
+
fill_order = ir.stride_order2fill_order(stride_order)
|
| 535 |
+
reversed_fill_order = list(reversed(fill_order))
|
| 536 |
+
size_with_stride_ordered_decreasingly = [
|
| 537 |
+
epilogue_node.get_size()[i] for i in reversed_fill_order
|
| 538 |
+
]
|
| 539 |
+
reshape_reindex = ir.View.dynamic_reshape_indexer(
|
| 540 |
+
size_with_stride_ordered_decreasingly,
|
| 541 |
+
template_buffer.get_size(),
|
| 542 |
+
)
|
| 543 |
+
if default_reindexer:
|
| 544 |
+
reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer)
|
| 545 |
+
|
| 546 |
+
# From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
|
| 547 |
+
# epilogue_node_ordered (ordered by stride decreasingly, in dense format):
|
| 548 |
+
# size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
|
| 549 |
+
# epilogue_node:
|
| 550 |
+
# size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
|
| 551 |
+
from_stride_ordered_decreasingly_to_epilogue_node_order = [
|
| 552 |
+
(len(stride_order) - 1) - stride_order[i]
|
| 553 |
+
for i in range(len(stride_order))
|
| 554 |
+
]
|
| 555 |
+
stride_reindex = ir.same_reorder(
|
| 556 |
+
from_stride_ordered_decreasingly_to_epilogue_node_order
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) # type: ignore[var-annotated]
|
| 560 |
+
return reindexer
|
| 561 |
+
|
| 562 |
+
if default_reindexers is None:
|
| 563 |
+
default_reindexers = [None] * len(epilogue_nodes)
|
| 564 |
+
new_reindexers = [
|
| 565 |
+
get_reindexer(epilogue_node, default_reindexer)
|
| 566 |
+
for epilogue_node, default_reindexer in zip(
|
| 567 |
+
epilogue_nodes, default_reindexers
|
| 568 |
+
)
|
| 569 |
+
]
|
| 570 |
+
reindexers.extend(new_reindexers)
|
| 571 |
+
if isinstance(Y, ir.BaseView):
|
| 572 |
+
storage = ir.StorageBox(Y.unwrap_view())
|
| 573 |
+
else:
|
| 574 |
+
assert isinstance(Y, ir.Buffer)
|
| 575 |
+
storage = ir.StorageBox(Y)
|
| 576 |
+
Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout())
|
| 577 |
+
return Y_2d, reindexers
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class CppGemmTemplate(CppTemplate):
|
| 581 |
+
"""
|
| 582 |
+
GEMM Template for Inductor CPP Backend.
|
| 583 |
+
"""
|
| 584 |
+
|
| 585 |
+
def __init__(
    self,
    input_nodes,
    layout: ir.Layout,
    num_threads: int,
    register_blocking: GemmBlocking,
    beta=1,
    alpha=1,
    has_bias=False,
    epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    should_block_weights: bool = True,
    name="packed_gemm",
) -> None:
    """
    Construct a CPP GEMM template instance.

    Args:
        input_nodes: GEMM inputs; input_nodes[0] is X (its last dim gives K).
        layout: output layout; the last two dims of ``layout.size`` give M and N.
        num_threads: number of threads used to generate/parallelize the kernel.
        register_blocking: (block_m, block_n, block_k) register-tile sizes.
        beta, alpha: GEMM scaling factors (Y = alpha * X @ W + beta * bias).
        has_bias: whether a bias input is present.
        epilogue_creator: optional callback building a pointwise epilogue from
            the GEMM output buffer.
        should_block_weights: whether the weight tensor is pre-blocked/packed.
        name: template name passed to the base class.
    """
    # Only these dtypes are supported by the CPP GEMM micro-kernels.
    assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
    super().__init__(
        name,
        input_nodes,
        layout,
        num_threads,
        epilogue_creator=epilogue_creator,
    )
    self.beta = beta
    self.alpha = alpha
    self.has_bias = has_bias
    self.register_blocking = register_blocking
    # M, N come from the output layout; K from the last dim of X.
    m, n = layout.size[-2:]
    k = input_nodes[0].get_size()[-1]
    self.m, self.n, self.k = m, n, k
    # N padded up to a multiple of the register block along N.
    self.padded_n = get_padded_n(n, self.register_blocking.block_n)
    # Dynamic-shape M: thread/cache blocking must then be decided at runtime.
    self.is_dynamic_M = has_free_symbols((m,))
    self.should_block_weights = should_block_weights
    # Memoized per-instance blocking planners (see make_*_cache methods).
    self.thread_blocking = self.make_thread_blocking_cache()
    self.cache_blocking = self.make_cache_blocking_cache()
|
| 618 |
+
|
| 619 |
+
def make_thread_blocking_cache(self):
    """
    Build a memoized wrapper around ``self._thread_blocking``.

    The cache lives in this closure (per instance), so repeated queries for
    the same thread count do not recompute the blocking heuristic.
    """
    memoized = lru_cache()(self._thread_blocking)

    def thread_blocking(num_threads: int) -> GemmBlocking:
        # Delegate to the memoized heuristic.
        return memoized(num_threads)

    return thread_blocking
|
| 626 |
+
|
| 627 |
+
def _thread_blocking(self, num_threads: int) -> GemmBlocking:
    """
    NOTE [Thread blocking in Cpp GEMM]
    We use simple heuristics to decide the thread blocking:
    1. Make sure all threads are occupied as much as possible.
    2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
    3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
    TODO(jgong5): allow tuning various blocking options
    """

    def get_factors(number):
        # All divisor pairs of `number`, larger co-factor first for each pair,
        # scanning i from sqrt(number) down to 1.
        factors = []
        for i in range(int(number**0.5), 0, -1):
            if number % i == 0:
                factors.append(number // i)
                factors.append(i)
        return factors

    def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
        # Blocks-per-thread along each dim when splitting among the factors.
        thread_block_k = math.ceil(k_blocks / k_factor)
        thread_block_n = math.ceil(n_blocks / n_factor)
        thread_block_m = math.ceil(m_blocks / m_factor)
        return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

    assert not self.is_dynamic_M, (
        "Unable to determine thread blocking for dynamic M."
    )
    register_blocking = self.register_blocking
    # Total number of register blocks along each GEMM dimension.
    m_blocks = math.ceil(self.m / register_blocking.block_m)
    n_blocks = math.ceil(self.n / register_blocking.block_n)
    k_blocks = math.ceil(self.k / register_blocking.block_k)
    factors = get_factors(num_threads)
    assert len(factors) > 0

    # User override: "m,n,k" thread factors from config skip the heuristic.
    if config.cpp.gemm_thread_factors is not None:
        factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
        assert len(factors) == 3
        assert math.prod(factors) == self.num_threads
        return get_blocking(
            factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
        )

    # we favor square-sized thread blocks for good data reuse
    def get_better_blocking(blocking, best_blocking):
        # Prefer larger block_k (less k-slicing); break ties by the smaller
        # m+n block perimeter, i.e. the more square-shaped tile.
        if best_blocking is None:
            best_blocking = blocking
        else:
            block_m_size = blocking.block_m * register_blocking.block_m
            block_n_size = blocking.block_n * register_blocking.block_n
            best_block_m_size = best_blocking.block_m * register_blocking.block_m
            best_block_n_size = best_blocking.block_n * register_blocking.block_n
            if blocking.block_k > best_blocking.block_k:
                best_blocking = blocking
            elif (
                blocking.block_k == best_blocking.block_k
                and block_m_size + block_n_size
                < best_block_m_size + best_block_n_size
            ):
                best_blocking = blocking
        return best_blocking

    best_blocking = None
    # check if we can have a thread-blocking to occupy all threads without k-slicing
    for n_factor in factors:
        m_factor = num_threads // n_factor
        if n_blocks >= n_factor and m_blocks >= m_factor:
            blocking = get_blocking(
                m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
            )
            best_blocking = get_better_blocking(blocking, best_blocking)

    if best_blocking is None:
        # (m, n) alone cannot occupy all threads: try k-slicing, bounded by
        # config.cpp.gemm_max_k_slices (0 means unbounded).
        for k_factor in factors:
            if k_blocks >= k_factor and (
                config.cpp.gemm_max_k_slices == 0
                or k_factor <= config.cpp.gemm_max_k_slices
            ):
                n_factors = get_factors(num_threads // k_factor)
                for n_factor in n_factors:
                    m_factor = (num_threads // k_factor) // n_factor
                    if n_blocks >= n_factor and m_blocks >= m_factor:
                        blocking = get_blocking(
                            m_factor,
                            n_factor,
                            k_factor,
                            m_blocks,
                            n_blocks,
                            k_blocks,
                        )
                        best_blocking = get_better_blocking(blocking, best_blocking)

    if best_blocking is None:
        # Last resort: accept partial occupancy — only one of the (m, n)
        # dimensions needs to cover its factor.
        for n_factor in factors:
            m_factor = num_threads // n_factor
            if n_blocks >= n_factor or m_blocks >= m_factor:
                blocking = get_blocking(
                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                )
                best_blocking = get_better_blocking(blocking, best_blocking)

    assert best_blocking is not None
    return best_blocking
|
| 729 |
+
|
| 730 |
+
def make_cache_blocking_cache(self):
    """
    Build a memoized wrapper around ``self._cache_blocking``.

    The cache is held in this closure (per instance), so the cache-blocking
    heuristic runs at most once per thread count.
    """
    memoized = lru_cache()(self._cache_blocking)

    def cache_blocking(num_threads: int) -> GemmBlocking:
        # Delegate to the memoized heuristic.
        return memoized(num_threads)

    return cache_blocking
|
| 737 |
+
|
| 738 |
+
def _cache_blocking(self, num_threads: int) -> GemmBlocking:
    """
    Decide the cache blocking (Mc, Nc, Kc in units of register blocks) for the
    given thread count, sizing blocks to fit the L1/L2 caches.
    See NOTE [CPP GEMM Cache Blocking Algorithm] below for the strategy.
    """

    def get_cache_blocking(register_blocking, thread_blocking):
        # Register-tile sizes (elements).
        Mr = register_blocking.block_m
        Nr = register_blocking.block_n
        Kr = register_blocking.block_k

        # Per-thread block counts from the thread blocking.
        Mt_blocks = thread_blocking.block_m
        Nt_blocks = thread_blocking.block_n
        Kt_blocks = thread_blocking.block_k

        # User override: "Mc,Nc,Kc" from config, clamped to the thread blocking.
        if config.cpp.gemm_cache_blocking is not None:
            blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
            assert len(blockings) == 3
            Mc_blocks, Nc_blocks, Kc_blocks = blockings
            return (
                min(Mc_blocks, Mt_blocks),
                min(Nc_blocks, Nt_blocks),
                min(Kc_blocks, Kt_blocks),
            )

        # The ratios below are empirically determined to decide
        # the effective sizes of L1 and L2.
        # TODO: tune the factor here
        L1_limit_factor = 0.8
        L2_limit_factor = 0.5

        L1_cache_size = (
            torch._C._cpu._L1d_cache_size()
        )  # per core cache size in Bytes
        assert L1_cache_size > 0, (
            f"Expect L1_cache_size > 0 but got {L1_cache_size}"
        )
        L1 = L1_cache_size * L1_limit_factor

        L2_cache_size = (
            torch._C._cpu._L2_cache_size()
        )  # per core cache size in Bytes
        assert L2_cache_size > 0, (
            f"Expect L2_cache_size > 0 but got {L2_cache_size}"
        )
        L2 = L2_cache_size * L2_limit_factor

        def get_num_byte(dtype):
            # Bytes per element of `dtype`.
            return torch.tensor([], dtype=dtype).element_size()

        dtype_A = self.input_nodes[0].get_dtype()
        dtype_B = self.input_nodes[1].get_dtype()
        num_byte_A = get_num_byte(dtype_A)
        num_byte_B = get_num_byte(dtype_B)
        if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1:
            # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel.
            # In this case, the choice of the micro-kernel being used can't be decoupled from
            # the cache blocking.
            # TODO: Decouple the choice of micro-kernel from cache blocking
            num_byte_B *= num_byte_A

        # NOTE [CPP GEMM Cache Blocking Algorithm]
        # Our overall strategy is to
        # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
        #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
        #    decide Kc. We want to make Mc large enough to better reuse B.
        # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
        #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

        # Step 1: Decide Kc assuming B block is L1-reside.
        size_cache_B = Kr * Kt_blocks * Nr * num_byte_B

        Kc_blocks = Kt_blocks
        if size_cache_B > L1:
            Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))

        if (
            config.cpp.use_small_dequant_buffer
            and dtype_A is torch.bfloat16
            and dtype_B is torch.uint8
            and Mt_blocks == 1
        ):
            # Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
            # Since when Mt_blocks == 1, L1-reside B block can't be reused by A.
            if Kc_blocks * Kr >= self.q_group_size():
                Kc_blocks = self.q_group_size() // Kr

        # Step 2: Decide Mc assuming A block is L2-reside.
        min_Mc_ratio = 2  # TODO(jgong5): something to tune?
        min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
        assert min_Mc_blocks >= 1
        Kt_bytes = Kt_blocks * Kr * num_byte_A
        if min_Mc_blocks * Mr * Kt_bytes < L2:
            # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
            # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
            # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
            # in L1.
            Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
            Nc_blocks = 1
        else:
            # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
            # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
            Mc_blocks = Mt_blocks
            Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
            Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
            Kc_bytes = Kc_blocks * Kr * num_byte_A
            if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
                # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
                # assuming Mc == Nc for good data reuse.
                M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
                if M_max < Mc_blocks * Mr:
                    Mc_blocks = math.floor(M_max / Mr)
                    Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)

        return Mc_blocks, Nc_blocks, Kc_blocks

    assert not self.is_dynamic_M, (
        "Unable to determine cache blocking for dynamic M."
    )
    register_blocking = self.register_blocking
    thread_blocking = self.thread_blocking(num_threads)

    return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
|
| 856 |
+
|
| 857 |
+
def log_blockings(self):
    """
    Emit the chosen register/cache/thread blockings and the resulting thread
    occupancy at DEBUG level. For dynamic M, only the register blocking is
    logged, since the other blockings are decided at runtime.
    """
    # Use lazy %-style logging arguments instead of f-strings so the
    # formatting cost is only paid when DEBUG logging is actually enabled
    # (this also removes the need for the G004 lint suppression).
    log.debug("Register blocking: %s", self.register_blocking)
    if self.is_dynamic_M:
        # thread and cache blockings are determined at runtime for dynamic shapes
        return
    log.debug("Cache blocking: %s", self.cache_blocking(self.num_threads))
    thread_blocking = self.thread_blocking(self.num_threads)
    log.debug("Thread blocking: %s", thread_blocking)

    def get_occupancy():
        # How many thread-block tiles exist along each of (m, n, k);
        # (1, 1, 1) means the work is fully covered by one pass of threads.
        m_blocks = math.ceil(self.m / self.register_blocking.block_m)
        n_blocks = math.ceil(self.n / self.register_blocking.block_n)
        k_blocks = math.ceil(self.k / self.register_blocking.block_k)
        m = math.ceil(m_blocks / thread_blocking.block_m)
        n = math.ceil(n_blocks / thread_blocking.block_n)
        k = math.ceil(k_blocks / thread_blocking.block_k)
        return (m, n, k)

    log.debug(
        "Number of threads: %d, occupancy: %s", self.num_threads, get_occupancy()
    )
|
| 880 |
+
|
| 881 |
+
def maybe_k_slicing(self):
    """
    Report whether k-slicing (splitting the K dimension across threads) may
    be used for this GEMM.

    Returns False for single-threaded runs, True for dynamic M (decided
    conservatively), and otherwise checks whether the thread blocking leaves
    more than one K slice.
    """
    if self.num_threads == 1:
        # Nothing to slice across with a single thread.
        return False
    if self.is_dynamic_M:
        # TODO(jgong5): perhaps use size hint to decide?
        return True
    blocking = self.thread_blocking(self.num_threads)
    total_k_blocks = math.ceil(self.k / self.register_blocking.block_k)
    return total_k_blocks > blocking.block_k
|
| 891 |
+
|
| 892 |
+
@classmethod
def add_choices(
    cls,
    choices,
    layout,
    input_nodes,
    beta=1,
    alpha=1,
    has_bias=False,
    trans_w=False,
    input_indices=None,
    epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    act_mapping: Optional[dict[int, ir.IRNode]] = None,
):
    """
    Add choices for the GEMM template.

    Builds a DataProcessorTemplateWrapper with a preprocessor (input
    reordering, dense conversion, shape normalization, weight prep) and a
    postprocessor (prepacking constant weights into the template buffer),
    then appends it to ``choices``.
    """
    # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant
    use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all(
        (
            isinstance(input_nodes[idx], ir.TensorBox)
            and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer)
        )
        for idx in [1, 2, 4]
    )

    if input_indices is None:
        input_indices = list(range(len(input_nodes)))
    # True when the same node is used for both X and W (e.g. X @ X.T).
    only_one_input = (
        input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False
    )

    def reorder_and_filter(inputs, layout_or_out):
        # Canonicalize input order to [x, w, <optional bias>, ...].
        if has_bias:
            assert len(input_indices) >= 3
            # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
            inp_idx = input_indices[0]
            x_idx = input_indices[1]
            w_idx = input_indices[2]
            return [
                inputs[x_idx],
                inputs[w_idx],
                inputs[inp_idx],
                *[inputs[idx] for idx in input_indices[3:]],
            ], layout_or_out
        elif len(inputs) >= len(input_indices):
            assert len(input_indices) >= 2
            return [inputs[idx] for idx in input_indices], layout_or_out
        else:
            # For when input is used for x and w, i.e. X@X.T or similar
            # Assumes the first input is the only input
            assert len(inputs) == 1
            return [inputs[0]] * len(input_indices), layout_or_out

    new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
    is_mkldnn_wgt = (
        new_inputs[1].get_name() in V.graph.constants
        and V.graph.constants[new_inputs[1].get_name()].is_mkldnn
    )
    if is_mkldnn_wgt:
        # It shouldn't happen as viewing an mkldnn tensor, we can extend the
        # implementation if it does.
        assert not isinstance(new_inputs[1], ir.BaseView)
    # Note that the layout of MKLDNN Tensor is with the wrong stride
    view_size = new_inputs[1].layout.size
    view_stride = new_inputs[1].layout.stride
    view_offset = new_inputs[1].layout.offset

    def maybe_to_dense(inputs, layout_or_out):
        # Convert an MKLDNN weight tensor to a dense torch.Tensor.
        new_inputs = list(inputs)
        if isinstance(inputs[1], torch.Tensor):
            W = inputs[1]
            new_inputs[1] = W.to_dense() if W.is_mkldnn else W
        return new_inputs, layout_or_out

    def normalize_shapes(inputs, layout_or_out):
        new_inputs = list(inputs)
        if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor):
            if has_free_symbols(view_size):
                # If batch size B is dynamic, we need to set the batch size and possibly stride
                assert not has_free_symbols(view_size[1:])
                view_size[:] = V.graph.sizevars.size_hints(view_size)
                view_stride[:] = V.graph.sizevars.size_hints(view_stride)
            # With the assumptation that W is the storage of unwrap view
            # thus view it back here
            new_inputs[1] = new_inputs[1].as_strided(
                view_size, view_stride, view_offset
            )

        if not trans_w:
            return new_inputs, layout_or_out
        # trans_w: transpose W and expand the bias to match X.
        X = new_inputs[0]
        W = new_inputs[1]
        B = new_inputs[2] if has_bias else None
        W = transpose_w(W, trans_w)
        B = expand_bias(B, X)  # type:ignore[arg-type]
        new_inputs[1] = W
        if B is not None:
            new_inputs[2] = B
        return new_inputs, layout_or_out

    # TODO(jgong5): decide proper number of threads per problem size
    num_threads = parallel_num_threads()
    new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
    m, n, k, *_ = mm_args(
        new_inputs[0],
        new_inputs[1],
        mat2_transposed=cls.is_woq_int4(),
        use_4x2_dim=cls.is_woq_int4(),
    )
    output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
        new_inputs[0].get_dtype()
    )
    micro_gemm = create_micro_gemm(
        "micro_gemm",
        m,
        n,
        k,
        input_dtype=new_inputs[0].get_dtype(),
        input2_dtype=new_inputs[1].get_dtype(),
        output_dtype=output_dtype,
        compute_dtype=compute_dtype,
        alpha=alpha,
        num_threads=num_threads,
        use_ref=not cls.is_woq_int4(),
        q_group_size=cls.q_group_size(),
    )
    assert micro_gemm is not None
    pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm)
    micro_gemm.use_local_vnni_blocking(not pre_block_weights)

    def preprocessor(inputs, layout):
        # Full input pipeline: reorder -> densify -> normalize -> weight prep.
        new_inputs, new_layout = normalize_shapes(
            *maybe_to_dense(*reorder_and_filter(inputs, layout))
        )
        if only_one_input and isinstance(new_inputs[0], torch.Tensor):
            return new_inputs[1:], new_layout
        return cls.prep_weight(
            new_inputs,
            new_layout,
            micro_gemm,
            pre_block_weights,
            use_int8_fast_compensation_path,
        )

    def postprocessor(output):
        if isinstance(output, ir.TensorBox):
            # prepack the weight as input to the template buffer
            template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
            assert isinstance(template_buffer, ir.CppTemplateBuffer)
            new_input_nodes, _ = reorder_and_filter(input_nodes, layout)

            W_node = new_input_nodes[1]
            if W_node.get_name() not in V.graph.constants:
                # Non-constant weight: nothing to prepack at compile time.
                return output
            W = V.graph.constants[W_node.get_name()]
            new_input_nodes[1] = W
            new_input_nodes, new_layout = normalize_shapes(
                *maybe_to_dense(new_input_nodes, layout)
            )
            new_input_nodes, _ = cls.prep_weight(
                new_input_nodes,
                new_layout,
                micro_gemm,
                pre_block_weights,
                use_int8_fast_compensation_path,
                skip_int8_compensation=True,
            )
            W_packed = new_input_nodes[1]
            W_packed_constant = V.graph.add_tensor_constant(W_packed)
            new_input_nodes[1] = W_packed_constant

            # Prune unused tensors
            prune_tensors(input_nodes, new_input_nodes)

            template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                W_packed_constant
            )
        return output

    template = DataProcessorTemplateWrapper(
        cls,
        preprocessor,
        postprocessor,
        input_nodes=input_nodes,
        layout=layout,
        num_threads=num_threads,
        register_blocking=micro_gemm.register_blocking,
        beta=beta,
        alpha=alpha,
        has_bias=has_bias,
        epilogue_creator=epilogue_creator,
        should_block_weights=pre_block_weights,
        name=micro_gemm.__class__.__name__,
    )
    template.maybe_append_choice(choices)
    return template
|
| 1089 |
+
|
| 1090 |
+
@staticmethod
def get_padded_size(n, block_n, k, should_block_weight):
    """
    Compute the blocked 3D weight size and the padded N dimension.

    Returns ``([padded_n // block_n, k, block_n], padded_n)`` where
    ``padded_n`` is ``n`` rounded up to a multiple of ``block_n``.
    """
    padded_n = get_padded_n(n, block_n)
    # All GEMM weight tensors are assumed to be blocked and padded here;
    # subclasses (e.g. BMM) may override this behavior.
    blocked_size = [padded_n // block_n, k, block_n]
    return blocked_size, padded_n
|
| 1096 |
+
|
| 1097 |
+
@classmethod
def prep_weight(
    cls,
    inputs,
    layout: ir.Layout,
    micro_gemm: CppMicroGemm,
    should_block_weight: bool,
    use_int8_fast_compensation_path: bool = False,
    skip_int8_compensation: bool = False,
):
    """
    NOTE Weight prep consists of 2 separate steps:
    1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n]
       This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM.
       For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway.
       This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel,
       and is worth the overhead of reshape and blocking.

       This blocking includes additional padding, when n is not a multiple of block_n.
       This padding allows a more efficient microkernel implementation. For BMM, this is only done
       if reshape would happen anyway, i.e. if the weight tensor is constant, is not contiguous,
       or is using AMX VNNI layout.
    2. Packing the weight tensor into a VNNI-friendly shape. For constant input,
       this is done at the same time as the weight blocking.

    At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM)
    which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime.

    CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate
    an additional dimension for the batch size and to determine if the weight tensor should be blocked.
    """
    W = inputs[1]
    new_inputs = list(inputs)
    if cls.is_woq_int4():
        # WOQ int4 weights are 2D with shape (n, k).
        assert (
            len(W.get_size()) == 2
            if isinstance(W, ir.IRNode)
            else len(W.shape) == 2
        )
        n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape
    else:
        # Otherwise the trailing two dims of W give (k, n).
        k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:]
    _, block_n, _ = micro_gemm.register_blocking
    new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
    # Number of padding columns appended along N.
    padding = padded_n - n

    if should_block_weight and not cls.is_woq_int4():
        # Step 1 + 2: block into [n//block_n, k, block_n] and VNNI-pack.
        blocked_w = cls.block_weight(W, new_size, padding)
        new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)
    elif should_block_weight:
        # WOQ int4: block only; packing is handled in the micro-kernel.
        assert cls.is_woq_int4()
        new_inputs[1] = cls.block_weight(W, new_size, padding)
    elif isinstance(W, ir.IRNode):
        # Require W layout to be fixed & contiguous, happens inplace.
        ir.ExternKernel.require_contiguous(W)

    if not skip_int8_compensation and _is_int8_gemm(new_inputs):
        BCompensate = None
        x_w_scale = None

        def _get_compensation_node(W, use_int8_fast_compensation_path):
            # Fetch precomputed compensation constants registered for W.
            BCompensate = V.graph.add_tensor_constant(
                V.graph.constants[W.get_name() + "_BMatrixCompens"],
                W.get_name() + "_BMatrixCompens",
            )
            x_w_scale = None
            if use_int8_fast_compensation_path:
                x_w_scale = V.graph.add_tensor_constant(
                    V.graph.constants[W.get_name() + "_x_w_compens"],
                    W.get_name() + "_x_w_compens",
                )
            return BCompensate, x_w_scale

        if use_int8_fast_compensation_path:
            # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp]
            x_scale = new_inputs[-4]
            x_zp = new_inputs[-3]
            w_scale = new_inputs[-2]
            if isinstance(W, ir.IRNode):
                BCompensate, x_w_scale = _get_compensation_node(
                    W, use_int8_fast_compensation_path
                )
            else:
                # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
                BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
                assert all(
                    isinstance(item, torch.Tensor)
                    for item in (x_scale, x_zp, w_scale)
                )
                BCompensate = BCompensate * x_scale * w_scale * x_zp
                x_w_scale = x_scale * w_scale
            new_inputs.append(BCompensate)
            new_inputs.append(x_w_scale)
        else:
            if isinstance(W, ir.IRNode):
                BCompensate, _ = _get_compensation_node(
                    W, use_int8_fast_compensation_path
                )
            else:
                # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
                BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
            new_inputs.append(BCompensate)
    return new_inputs, layout
|
| 1200 |
+
|
| 1201 |
+
@staticmethod
def check_if_block_weight(W, micro_gemm):
    """
    Decide whether the weight ``W`` should be pre-blocked.

    Always True for the GEMM template; subclasses may override with a
    tensor/micro-kernel-dependent decision (hence the unused parameters).
    """
    return True
|
| 1204 |
+
|
| 1205 |
+
@classmethod
def block_weight(cls, W, new_size, padding):
    """
    Block the weight tensor into the 3D shape ``new_size``
    ([n//block_n, k, block_n]), padding the last dim of W with ``padding``
    zero columns first. Handles three cases: a constant IR weight (reuse its
    registered buffer with the blocked layout), a non-constant IR weight
    (pad/view/permute via lowering ops), and a plain torch.Tensor.
    """
    # These are separated into two methods to allow subclasses to override them separately
    if isinstance(W, ir.IRNode):
        if W.get_name() in V.graph.constants:
            # Create a new buffer, representing the constant blocked tensor
            blocked_w = ir.Buffer(
                name=W.get_name(),  # Borrow the registered buffer name
                layout=ir.FixedLayout(
                    W.get_device_or_error(),
                    W.get_dtype(),
                    new_size,
                    ir.FlexibleLayout.contiguous_strides(new_size),
                    0,
                ),
            )
        else:
            if not isinstance(W, ir.TensorBox):
                W = ir.TensorBox(W)
            # Swap the last-but-one and last-but-two dims: view in
            # [k, n//block_n, block_n] order, then permute to new_size order.
            permute_dims = list(range(len(new_size)))
            permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2]
            permute_size = list(new_size)
            permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2]
            blocked_w = L.constant_pad_nd(W, (0, padding))
            blocked_w = L.permute(
                L.view(blocked_w, permute_size),
                permute_dims,
            )
    else:
        assert isinstance(W, torch.Tensor)
        # Pad the weight tensor and reshape it into a 3D blocked shape
        blocked_size = list(new_size)
        blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2]
        blocked_w = (
            torch.nn.functional.pad(W, (0, padding))  # type: ignore[assignment]
            .reshape(*blocked_size)
            .transpose(-3, -2)
            .contiguous()
        )
    return blocked_w
|
| 1245 |
+
|
| 1246 |
+
@classmethod
|
| 1247 |
+
def pack_vnni_weight(cls, W, micro_gemm, new_size):
|
| 1248 |
+
# WOQ INT4 weights are reordered in microkernel so do not pack them here
|
| 1249 |
+
should_pack = (
|
| 1250 |
+
micro_gemm.get_b_layout() != LayoutType.NORMAL
|
| 1251 |
+
and not micro_gemm.is_woq_int4()
|
| 1252 |
+
)
|
| 1253 |
+
|
| 1254 |
+
# These are separated into two methods to allow subclasses to override them separately
|
| 1255 |
+
if isinstance(W, ir.IRNode):
|
| 1256 |
+
if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants:
|
| 1257 |
+
return W
|
| 1258 |
+
k = new_size[-2]
|
| 1259 |
+
if not isinstance(W, ir.TensorBox):
|
| 1260 |
+
W = ir.TensorBox(W)
|
| 1261 |
+
if should_pack:
|
| 1262 |
+
permute_dims = list(range(len(new_size) + 1))
|
| 1263 |
+
permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1]
|
| 1264 |
+
vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
|
| 1265 |
+
vnni_view_size = list(new_size)
|
| 1266 |
+
vnni_view_size[-2] = k // vnni_size
|
| 1267 |
+
vnni_view_size.insert(-1, vnni_size)
|
| 1268 |
+
W = L.view(
|
| 1269 |
+
L.permute(L.view(W, vnni_view_size), permute_dims),
|
| 1270 |
+
new_size,
|
| 1271 |
+
)
|
| 1272 |
+
W = ir.ExternKernel.realize_input(W)
|
| 1273 |
+
W = ir.ExternKernel.require_contiguous(W)
|
| 1274 |
+
return W
|
| 1275 |
+
else:
|
| 1276 |
+
k = new_size[-2]
|
| 1277 |
+
# Apply VNNI packing to the weight tensor
|
| 1278 |
+
if should_pack:
|
| 1279 |
+
# TODO: Move VNNI weight packing for non-constant tensors into the template,
|
| 1280 |
+
# to improve cache locality and avoid full-tensor copy.
|
| 1281 |
+
layout_str = (
|
| 1282 |
+
"VNNI4"
|
| 1283 |
+
if micro_gemm.get_b_layout() == LayoutType.VNNI4
|
| 1284 |
+
else "VNNI2"
|
| 1285 |
+
)
|
| 1286 |
+
assert micro_gemm.get_b_layout() in [
|
| 1287 |
+
LayoutType.VNNI2,
|
| 1288 |
+
LayoutType.VNNI4,
|
| 1289 |
+
], f"We only support {layout_str} for now"
|
| 1290 |
+
vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
|
| 1291 |
+
assert k % vnni_size == 0, (
|
| 1292 |
+
f"k should be divisible by vnni_size for {layout_str} layout"
|
| 1293 |
+
)
|
| 1294 |
+
vnni_view_size = list(new_size)
|
| 1295 |
+
vnni_view_size[-2] = k // vnni_size
|
| 1296 |
+
vnni_view_size.insert(-1, vnni_size)
|
| 1297 |
+
W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size)
|
| 1298 |
+
# normalize stride to be "contiguous_strides" per size
|
| 1299 |
+
# this avoids the problems in L.view during template codegen
|
| 1300 |
+
new_stride = [1]
|
| 1301 |
+
for sz in reversed(W.shape[1:]):
|
| 1302 |
+
new_stride.insert(0, new_stride[0] * sz)
|
| 1303 |
+
W = W.as_strided(W.shape, new_stride)
|
| 1304 |
+
return W
|
| 1305 |
+
|
| 1306 |
+
def get_default_reindexers(self, epilogue_nodes):
|
| 1307 |
+
return [None] * len(epilogue_nodes)
|
| 1308 |
+
|
| 1309 |
+
def get_options(
|
| 1310 |
+
self,
|
| 1311 |
+
kernel: CppTemplateKernel,
|
| 1312 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 1313 |
+
flag_template_buffer_has_other_users: Optional[bool] = None,
|
| 1314 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 1315 |
+
) -> dict[str, Any]:
|
| 1316 |
+
assert len(self.input_nodes) >= 2
|
| 1317 |
+
|
| 1318 |
+
int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8]
|
| 1319 |
+
x_scale = None
|
| 1320 |
+
x_zp = None
|
| 1321 |
+
w_scale = None
|
| 1322 |
+
w_zp = None
|
| 1323 |
+
inp = None
|
| 1324 |
+
q_group_size_node = None
|
| 1325 |
+
qscale_and_zeros = None
|
| 1326 |
+
if int8_gemm:
|
| 1327 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 1328 |
+
bias_idx = 2 if self.has_bias else 1
|
| 1329 |
+
inp = self.input_nodes[bias_idx] if self.has_bias else None
|
| 1330 |
+
x_scale = self.input_nodes[bias_idx + 1]
|
| 1331 |
+
x_zp = self.input_nodes[bias_idx + 2]
|
| 1332 |
+
w_scale = self.input_nodes[bias_idx + 3]
|
| 1333 |
+
w_zp = self.input_nodes[bias_idx + 4]
|
| 1334 |
+
Y = self.output_node
|
| 1335 |
+
elif self.is_woq_int4():
|
| 1336 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 1337 |
+
Y = self.output_node
|
| 1338 |
+
q_group_size_node = self.input_nodes[2]
|
| 1339 |
+
qscale_and_zeros = self.input_nodes[3]
|
| 1340 |
+
else:
|
| 1341 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 1342 |
+
Y = self.output_node
|
| 1343 |
+
inp = self.input_nodes[2] if self.has_bias else None
|
| 1344 |
+
|
| 1345 |
+
template_buffer_has_other_users = None
|
| 1346 |
+
|
| 1347 |
+
if template_buffer_node is not None:
|
| 1348 |
+
# Use the updated prepacked weight buffer
|
| 1349 |
+
W = template_buffer_node.inputs[1]
|
| 1350 |
+
Y = template_buffer_node
|
| 1351 |
+
|
| 1352 |
+
assert flag_template_buffer_has_other_users is not None
|
| 1353 |
+
template_buffer_has_other_users = flag_template_buffer_has_other_users
|
| 1354 |
+
|
| 1355 |
+
template_buffer = Y
|
| 1356 |
+
gemm_output_buffer = template_buffer
|
| 1357 |
+
|
| 1358 |
+
epilogues: list[ir.IRNode] = []
|
| 1359 |
+
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
|
| 1360 |
+
epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
|
| 1361 |
+
fake_buffers: list[ir.Buffer] = []
|
| 1362 |
+
Y_aliases: OrderedSet[str] = OrderedSet()
|
| 1363 |
+
|
| 1364 |
+
use_local_acc = (
|
| 1365 |
+
self.layout.dtype != torch.float
|
| 1366 |
+
or template_buffer_has_other_users
|
| 1367 |
+
or int8_gemm
|
| 1368 |
+
or self.padded_n != self.n
|
| 1369 |
+
or self.maybe_k_slicing()
|
| 1370 |
+
or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype)
|
| 1371 |
+
)
|
| 1372 |
+
|
| 1373 |
+
# TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
|
| 1374 |
+
# but we'd better move it here to align with fp.
|
| 1375 |
+
if inp is not None and self.beta != 0 and not int8_gemm:
|
| 1376 |
+
# add an epilogue for bias add
|
| 1377 |
+
def _bias_add_epilogue(buf):
|
| 1378 |
+
return create_epilogue_with_attr(
|
| 1379 |
+
buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
|
| 1380 |
+
)
|
| 1381 |
+
|
| 1382 |
+
epilogue_creators.append(_bias_add_epilogue)
|
| 1383 |
+
|
| 1384 |
+
if self.epilogue_creator is not None:
|
| 1385 |
+
epilogue_creators.append(self.epilogue_creator)
|
| 1386 |
+
|
| 1387 |
+
# When the GEMM output buffer is localized but it has users other than the epilogue nodes,
|
| 1388 |
+
# we need to copy the value in the GEMM output local buffer to a global buffer.
|
| 1389 |
+
def need_copy_from_local_to_global_buffer_epilogue(
|
| 1390 |
+
use_local_acc, template_buffer_has_other_users, epilogue_creators
|
| 1391 |
+
):
|
| 1392 |
+
# The GEMM output buffer is a global buffer, thus copy is not needed.
|
| 1393 |
+
if not use_local_acc:
|
| 1394 |
+
return False
|
| 1395 |
+
|
| 1396 |
+
# The possible value of template_buffer_has_other_users is (None, False, True)
|
| 1397 |
+
# It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
|
| 1398 |
+
# extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
|
| 1399 |
+
# 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
|
| 1400 |
+
# 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
|
| 1401 |
+
# GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
|
| 1402 |
+
if not template_buffer_has_other_users:
|
| 1403 |
+
return False
|
| 1404 |
+
|
| 1405 |
+
# When bias is not None or self.epilogue_creator is not None,
|
| 1406 |
+
# there will be epilogue_creators after the GEMM.
|
| 1407 |
+
# The GEMM output buffer is localized while
|
| 1408 |
+
# the output buffer of the epilogue_creators is a global buffer.
|
| 1409 |
+
if epilogue_creators:
|
| 1410 |
+
return False
|
| 1411 |
+
|
| 1412 |
+
return True
|
| 1413 |
+
|
| 1414 |
+
if need_copy_from_local_to_global_buffer_epilogue(
|
| 1415 |
+
use_local_acc, template_buffer_has_other_users, epilogue_creators
|
| 1416 |
+
):
|
| 1417 |
+
|
| 1418 |
+
def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
|
| 1419 |
+
dtype = self.layout.dtype
|
| 1420 |
+
input_loader = input_buffer.make_loader()
|
| 1421 |
+
|
| 1422 |
+
def copy_inner(index):
|
| 1423 |
+
input = input_loader(index)
|
| 1424 |
+
result = ops.to_dtype(input, dtype)
|
| 1425 |
+
return result
|
| 1426 |
+
|
| 1427 |
+
return ir.Pointwise(
|
| 1428 |
+
device=input_buffer.get_device_or_error(),
|
| 1429 |
+
dtype=self.layout.dtype,
|
| 1430 |
+
inner_fn=copy_inner,
|
| 1431 |
+
ranges=input_buffer.get_size(),
|
| 1432 |
+
)
|
| 1433 |
+
|
| 1434 |
+
epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
|
| 1435 |
+
|
| 1436 |
+
# NOTE [How CPP GEMM template epilogues are organized]
|
| 1437 |
+
# gemm_output_buffer
|
| 1438 |
+
# --> zero or more in-template epilogues (created by `epilogue_creators`) -->
|
| 1439 |
+
# template_buffer
|
| 1440 |
+
# --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
|
| 1441 |
+
# Y
|
| 1442 |
+
if epilogue_creators:
|
| 1443 |
+
assert isinstance(template_buffer, ir.IRNode)
|
| 1444 |
+
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
|
| 1445 |
+
gemm_output_buffer = ir.Buffer(
|
| 1446 |
+
name=gemm_output_name, layout=template_buffer.layout
|
| 1447 |
+
)
|
| 1448 |
+
current_input_buffer = gemm_output_buffer
|
| 1449 |
+
for i, creator in enumerate(epilogue_creators):
|
| 1450 |
+
if i == len(epilogue_creators) - 1:
|
| 1451 |
+
buffer_name = template_buffer.get_name()
|
| 1452 |
+
else:
|
| 1453 |
+
buffer_name = f"{gemm_output_name}_epilogue_{i}"
|
| 1454 |
+
epilogues.append(
|
| 1455 |
+
ir.ComputedBuffer(
|
| 1456 |
+
name=buffer_name,
|
| 1457 |
+
layout=template_buffer.layout,
|
| 1458 |
+
data=creator(current_input_buffer),
|
| 1459 |
+
)
|
| 1460 |
+
)
|
| 1461 |
+
fake_buffers.append(current_input_buffer)
|
| 1462 |
+
Y_aliases.add(current_input_buffer.get_name())
|
| 1463 |
+
reindexers.append(None)
|
| 1464 |
+
if i < len(epilogue_creators) - 1:
|
| 1465 |
+
current_input_buffer = ir.Buffer(
|
| 1466 |
+
name=buffer_name, layout=template_buffer.layout
|
| 1467 |
+
)
|
| 1468 |
+
|
| 1469 |
+
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
|
| 1470 |
+
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
|
| 1471 |
+
|
| 1472 |
+
if epilogue_nodes:
|
| 1473 |
+
if not template_buffer_has_other_users:
|
| 1474 |
+
assert isinstance(template_buffer, ir.IRNode)
|
| 1475 |
+
Y_aliases.add(template_buffer.get_name())
|
| 1476 |
+
epilogues.extend(epilogue_nodes)
|
| 1477 |
+
assert Y.get_numel() == epilogues[-1].get_numel()
|
| 1478 |
+
Y = cast(ir.Buffer, epilogues[-1])
|
| 1479 |
+
assert isinstance(template_buffer, ir.Buffer)
|
| 1480 |
+
Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
|
| 1481 |
+
Y,
|
| 1482 |
+
template_buffer,
|
| 1483 |
+
epilogue_nodes,
|
| 1484 |
+
reindexers,
|
| 1485 |
+
default_reindexers=self.get_default_reindexers(epilogue_nodes),
|
| 1486 |
+
)
|
| 1487 |
+
|
| 1488 |
+
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
|
| 1489 |
+
X.get_dtype()
|
| 1490 |
+
)
|
| 1491 |
+
micro_gemm = create_micro_gemm(
|
| 1492 |
+
f"{kernel.kernel_name}_micro_gemm",
|
| 1493 |
+
self.m,
|
| 1494 |
+
self.n,
|
| 1495 |
+
self.k,
|
| 1496 |
+
input_dtype=X.get_dtype(),
|
| 1497 |
+
input2_dtype=W.get_dtype(),
|
| 1498 |
+
output_dtype=output_dtype,
|
| 1499 |
+
compute_dtype=compute_dtype,
|
| 1500 |
+
alpha=self.alpha,
|
| 1501 |
+
num_threads=self.num_threads,
|
| 1502 |
+
use_ref=not self.is_woq_int4(),
|
| 1503 |
+
q_group_size=self.q_group_size(),
|
| 1504 |
+
)
|
| 1505 |
+
assert micro_gemm is not None
|
| 1506 |
+
micro_gemm.use_local_vnni_blocking(not self.should_block_weights)
|
| 1507 |
+
assert self.register_blocking == micro_gemm.register_blocking
|
| 1508 |
+
self.log_blockings()
|
| 1509 |
+
if isinstance(micro_gemm, CppMicroGemmAMX):
|
| 1510 |
+
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
|
| 1511 |
+
if isinstance(micro_gemm, CppMicroBrgemm):
|
| 1512 |
+
counters["inductor"]["cpp_micro_brgemm_counter"] += 1
|
| 1513 |
+
|
| 1514 |
+
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
|
| 1515 |
+
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
|
| 1516 |
+
|
| 1517 |
+
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
|
| 1518 |
+
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
|
| 1519 |
+
|
| 1520 |
+
options = dict(
|
| 1521 |
+
X=X,
|
| 1522 |
+
W=W,
|
| 1523 |
+
inp=inp,
|
| 1524 |
+
Y=Y,
|
| 1525 |
+
N=self.n,
|
| 1526 |
+
K=self.k,
|
| 1527 |
+
PADDED_N=self.padded_n,
|
| 1528 |
+
GemmOut=gemm_output_buffer,
|
| 1529 |
+
aliases={alias: Y.get_name() for alias in Y_aliases},
|
| 1530 |
+
beta=self.beta,
|
| 1531 |
+
alpha=self.alpha,
|
| 1532 |
+
num_threads=self.num_threads,
|
| 1533 |
+
micro_gemm=micro_gemm,
|
| 1534 |
+
is_dynamic_M=self.is_dynamic_M,
|
| 1535 |
+
template=self,
|
| 1536 |
+
kernel=kernel,
|
| 1537 |
+
export_declaration=get_export_declaration(),
|
| 1538 |
+
epilogue_nodes=epilogues,
|
| 1539 |
+
reindexers=reindexers,
|
| 1540 |
+
Y_2d=Y_2d,
|
| 1541 |
+
use_local_acc=use_local_acc,
|
| 1542 |
+
maybe_k_slicing=self.maybe_k_slicing(),
|
| 1543 |
+
x_scale=x_scale,
|
| 1544 |
+
x_zp=x_zp,
|
| 1545 |
+
w_scale=w_scale,
|
| 1546 |
+
w_zp=w_zp,
|
| 1547 |
+
acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
|
| 1548 |
+
DTYPE_TO_CPP=DTYPE_TO_CPP,
|
| 1549 |
+
L1_cache_size=L1_cache_size,
|
| 1550 |
+
L2_cache_size=L2_cache_size,
|
| 1551 |
+
config=config,
|
| 1552 |
+
fake_buffers=fake_buffers,
|
| 1553 |
+
is_woq_int4=self.is_woq_int4(),
|
| 1554 |
+
q_group_size=q_group_size_node,
|
| 1555 |
+
qscale_and_zeros=qscale_and_zeros,
|
| 1556 |
+
)
|
| 1557 |
+
return options
|
| 1558 |
+
|
| 1559 |
+
def is_int8_woq_gemm_small_m_dim(
|
| 1560 |
+
self,
|
| 1561 |
+
X: ir.ReinterpretView,
|
| 1562 |
+
W: ir.ReinterpretView,
|
| 1563 |
+
N,
|
| 1564 |
+
K,
|
| 1565 |
+
micro_gemm,
|
| 1566 |
+
):
|
| 1567 |
+
"""Use SMALL_M_GEMM_TEMPLATE"""
|
| 1568 |
+
return (
|
| 1569 |
+
isinstance(micro_gemm, CppMicroGemmFP32Vec)
|
| 1570 |
+
and is_int8_woq_gemm_small_m_dim_corner_case(
|
| 1571 |
+
micro_gemm, X.get_size()[0], N, K
|
| 1572 |
+
)
|
| 1573 |
+
and X.get_dtype() is torch.bfloat16
|
| 1574 |
+
and W.get_dtype() is torch.int8
|
| 1575 |
+
)
|
| 1576 |
+
|
| 1577 |
+
def render( # type: ignore[override, return]
|
| 1578 |
+
self,
|
| 1579 |
+
kernel: CppTemplateKernel,
|
| 1580 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 1581 |
+
flag_template_buffer_has_other_users: Optional[bool] = None,
|
| 1582 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 1583 |
+
**kwargs,
|
| 1584 |
+
) -> str:
|
| 1585 |
+
options = self.get_options(
|
| 1586 |
+
kernel=kernel,
|
| 1587 |
+
template_buffer_node=template_buffer_node,
|
| 1588 |
+
flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
|
| 1589 |
+
epilogue_nodes=epilogue_nodes,
|
| 1590 |
+
)
|
| 1591 |
+
self.render_options = options
|
| 1592 |
+
|
| 1593 |
+
with contextlib.ExitStack() as stack:
|
| 1594 |
+
for buf in options["fake_buffers"]:
|
| 1595 |
+
stack.enter_context(
|
| 1596 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
|
| 1597 |
+
)
|
| 1598 |
+
if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim(
|
| 1599 |
+
options["X"],
|
| 1600 |
+
options["W"],
|
| 1601 |
+
options["N"],
|
| 1602 |
+
options["K"],
|
| 1603 |
+
options["micro_gemm"],
|
| 1604 |
+
):
|
| 1605 |
+
template_str = SMALL_M_GEMM_TEMPLATE
|
| 1606 |
+
else:
|
| 1607 |
+
template_str = GEMM_TEMPLATE
|
| 1608 |
+
return self._template_from_string(template_str).render(**options)
|
| 1609 |
+
|
| 1610 |
+
def codegen_blocks(
|
| 1611 |
+
self,
|
| 1612 |
+
num_threads,
|
| 1613 |
+
N,
|
| 1614 |
+
K,
|
| 1615 |
+
micro_gemm,
|
| 1616 |
+
is_dynamic_M,
|
| 1617 |
+
kernel,
|
| 1618 |
+
GemmOut,
|
| 1619 |
+
config,
|
| 1620 |
+
L1_cache_size,
|
| 1621 |
+
L2_cache_size,
|
| 1622 |
+
X,
|
| 1623 |
+
W,
|
| 1624 |
+
):
|
| 1625 |
+
options = dict(
|
| 1626 |
+
num_threads=num_threads,
|
| 1627 |
+
N=N,
|
| 1628 |
+
K=K,
|
| 1629 |
+
micro_gemm=micro_gemm,
|
| 1630 |
+
is_dynamic_M=is_dynamic_M,
|
| 1631 |
+
kernel=kernel,
|
| 1632 |
+
GemmOut=GemmOut,
|
| 1633 |
+
config=config,
|
| 1634 |
+
L1_cache_size=L1_cache_size,
|
| 1635 |
+
L2_cache_size=L2_cache_size,
|
| 1636 |
+
template=self,
|
| 1637 |
+
X=X,
|
| 1638 |
+
W=W,
|
| 1639 |
+
is_woq_int4=self.is_woq_int4(),
|
| 1640 |
+
)
|
| 1641 |
+
template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK
|
| 1642 |
+
if not (
|
| 1643 |
+
not is_dynamic_M
|
| 1644 |
+
and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm)
|
| 1645 |
+
):
|
| 1646 |
+
template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED
|
| 1647 |
+
return self._template_from_string(template_str).render(options)
|
| 1648 |
+
|
| 1649 |
+
def codegen_microkernel_def(self):
|
| 1650 |
+
return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render(
|
| 1651 |
+
self.render_options
|
| 1652 |
+
)
|
| 1653 |
+
|
| 1654 |
+
def codegen_gemm_stub_def(self):
|
| 1655 |
+
microkernel = self.codegen_microkernel_def()
|
| 1656 |
+
return microkernel + self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render(
|
| 1657 |
+
self.render_options
|
| 1658 |
+
)
|
| 1659 |
+
|
| 1660 |
+
def codegen_multi_threads_params(self):
|
| 1661 |
+
return self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS).render()
|
| 1662 |
+
|
| 1663 |
+
def codegen_single_thread_params(self, is_dynamic_M):
|
| 1664 |
+
options = dict(
|
| 1665 |
+
is_dynamic_M=is_dynamic_M,
|
| 1666 |
+
)
|
| 1667 |
+
return self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS).render(
|
| 1668 |
+
options
|
| 1669 |
+
)
|
| 1670 |
+
|
| 1671 |
+
def codegen_m_loop_params(self):
|
| 1672 |
+
return self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS).render()
|
| 1673 |
+
|
| 1674 |
+
def codegen_n_loop_params(self):
|
| 1675 |
+
return self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS).render()
|
| 1676 |
+
|
| 1677 |
+
@classmethod
|
| 1678 |
+
def is_woq_int4(cls):
|
| 1679 |
+
return False
|
| 1680 |
+
|
| 1681 |
+
@classmethod
|
| 1682 |
+
def q_group_size(cls):
|
| 1683 |
+
return None
|
| 1684 |
+
|
| 1685 |
+
|
| 1686 |
+
class CppWoqInt4GemmTemplateMeta(type):
|
| 1687 |
+
def __getitem__(cls, q_group_size):
|
| 1688 |
+
class CppWoqInt4GemmTemplateInstance(CppGemmTemplate):
|
| 1689 |
+
def __init__(
|
| 1690 |
+
self,
|
| 1691 |
+
*args,
|
| 1692 |
+
**kwargs,
|
| 1693 |
+
) -> None:
|
| 1694 |
+
super().__init__(
|
| 1695 |
+
*args,
|
| 1696 |
+
**kwargs,
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
@classmethod
|
| 1700 |
+
def is_woq_int4(cls):
|
| 1701 |
+
return True
|
| 1702 |
+
|
| 1703 |
+
@classmethod
|
| 1704 |
+
def q_group_size(cls):
|
| 1705 |
+
return q_group_size
|
| 1706 |
+
|
| 1707 |
+
@staticmethod
|
| 1708 |
+
def check_if_block_weight(W, micro_gemm):
|
| 1709 |
+
# For WOQ INT4, weight is already packed
|
| 1710 |
+
# However, for AMX microkernel, we want to change the blocking of weight
|
| 1711 |
+
from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx
|
| 1712 |
+
|
| 1713 |
+
return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx)
|
| 1714 |
+
|
| 1715 |
+
@classmethod
|
| 1716 |
+
def block_weight(cls, W, new_size, padding):
|
| 1717 |
+
# This method is called only if AMX microkernels are used.
|
| 1718 |
+
# In this case, we unpack and repack weight so that block_n=32
|
| 1719 |
+
# the format of packed weight is described here:
|
| 1720 |
+
# https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583
|
| 1721 |
+
if isinstance(W, ir.IRNode):
|
| 1722 |
+
# in this case, we do nothing
|
| 1723 |
+
ir.ExternKernel.require_contiguous(W)
|
| 1724 |
+
blocked_w = W
|
| 1725 |
+
else:
|
| 1726 |
+
# in this case, we unpack and repack weight
|
| 1727 |
+
assert isinstance(W, torch.Tensor)
|
| 1728 |
+
assert W.dim() == 2
|
| 1729 |
+
N = W.size(0)
|
| 1730 |
+
K = W.size(-1) * 2
|
| 1731 |
+
G = cls.q_group_size()
|
| 1732 |
+
# x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel
|
| 1733 |
+
# so that the unpacking process is faster
|
| 1734 |
+
x = torch.eye(K).bfloat16()
|
| 1735 |
+
# Here we use scale=1 and qzero=8 because we want to unpack weight
|
| 1736 |
+
# without dequantizing it. The qzero here is 8 instead of 0 because
|
| 1737 |
+
# int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel:
|
| 1738 |
+
# https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95
|
| 1739 |
+
qscales_and_zeros = (
|
| 1740 |
+
torch.tensor([1.0, 8.0])
|
| 1741 |
+
.bfloat16()
|
| 1742 |
+
.expand(K // G, N, 2)
|
| 1743 |
+
.contiguous()
|
| 1744 |
+
)
|
| 1745 |
+
# shape: [K, N]
|
| 1746 |
+
unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu(
|
| 1747 |
+
x,
|
| 1748 |
+
W,
|
| 1749 |
+
G,
|
| 1750 |
+
qscales_and_zeros,
|
| 1751 |
+
).to(torch.uint8)
|
| 1752 |
+
block_n = 32
|
| 1753 |
+
# shape: [N // block_n, K, block_n]
|
| 1754 |
+
w_blocked = (
|
| 1755 |
+
unpacked_w.view(K, N // block_n, block_n)
|
| 1756 |
+
.permute(1, 0, 2)
|
| 1757 |
+
.contiguous()
|
| 1758 |
+
)
|
| 1759 |
+
# pack 2 int4 -> 1 int8
|
| 1760 |
+
# block_n: [a0, a1, ..., a15, b0, b1, ..., b15]
|
| 1761 |
+
# -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...]
|
| 1762 |
+
# shape: [N // block_n, K, 2, block_n // 2]
|
| 1763 |
+
w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2)
|
| 1764 |
+
# shape: [N // block_n, K, block_n // 2]
|
| 1765 |
+
w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | (
|
| 1766 |
+
w_blocked[:, :, 1, :] << 4
|
| 1767 |
+
)
|
| 1768 |
+
# shape: [N, K // 2]
|
| 1769 |
+
blocked_w = w_blocked_packed.view(N, K // 2)
|
| 1770 |
+
|
| 1771 |
+
return blocked_w
|
| 1772 |
+
|
| 1773 |
+
return CppWoqInt4GemmTemplateInstance
|
| 1774 |
+
|
| 1775 |
+
|
| 1776 |
+
class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta):
|
| 1777 |
+
pass
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Callable, cast, Optional, TypeVar
|
| 4 |
+
from unittest.mock import patch
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils
|
| 8 |
+
from torch.utils._ordered_set import OrderedSet
|
| 9 |
+
|
| 10 |
+
from ..._dynamo.utils import counters
|
| 11 |
+
from .. import config, ir
|
| 12 |
+
from ..kernel.mm_common import mm_args
|
| 13 |
+
from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
|
| 14 |
+
from ..utils import parallel_num_threads
|
| 15 |
+
from ..virtualized import V
|
| 16 |
+
from .cpp import get_export_declaration
|
| 17 |
+
from .cpp_gemm_template import (
|
| 18 |
+
CppGemmTemplate,
|
| 19 |
+
expand_bias,
|
| 20 |
+
gen_2d_view_of_epilogue_buf,
|
| 21 |
+
prune_tensors,
|
| 22 |
+
transpose_w,
|
| 23 |
+
)
|
| 24 |
+
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
|
| 25 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 26 |
+
from .cpp_utils import (
|
| 27 |
+
create_epilogue_with_attr,
|
| 28 |
+
DTYPE_TO_CPP,
|
| 29 |
+
GemmBlocking,
|
| 30 |
+
get_gemm_template_output_and_compute_dtype,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
log = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
GEMM_TEMPLATE = r"""
|
| 37 |
+
{{template.header().getvalue()}}
|
| 38 |
+
{{micro_gemm.codegen_define(kernel)}}
|
| 39 |
+
|
| 40 |
+
extern "C" {{export_declaration}}
|
| 41 |
+
{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
|
| 42 |
+
{
|
| 43 |
+
{{kernel.maybe_codegen_profile()}}
|
| 44 |
+
{{ template.codegen_blocks(
|
| 45 |
+
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
|
| 46 |
+
) }}
|
| 47 |
+
{%- if num_threads > 1 %}
|
| 48 |
+
#pragma omp parallel num_threads({{num_threads}})
|
| 49 |
+
{
|
| 50 |
+
{{ template.codegen_multi_threads_params()|indent(8, false) }}
|
| 51 |
+
{%- else %}
|
| 52 |
+
{
|
| 53 |
+
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
|
| 54 |
+
{%- endif %}
|
| 55 |
+
{{ micro_gemm.codegen_init(kernel) }}
|
| 56 |
+
{%- set acc_buf_name_list=[] %}
|
| 57 |
+
{%- set acc_buf_name_prefix = "local_acc_buf_" %}
|
| 58 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 59 |
+
{%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
|
| 60 |
+
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
|
| 61 |
+
{%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
|
| 62 |
+
{%- endfor %}
|
| 63 |
+
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
|
| 64 |
+
{{ template.codegen_m_loop_params()|indent(12, false) }}
|
| 65 |
+
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
| 66 |
+
{{ template.codegen_n_loop_params()|indent(16, false) }}
|
| 67 |
+
{%- set acc_list=[] %}
|
| 68 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 69 |
+
{%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
|
| 70 |
+
{{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
|
| 71 |
+
{%- endfor %}
|
| 72 |
+
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
|
| 73 |
+
int64_t k_start = kc * Kr;
|
| 74 |
+
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
|
| 75 |
+
{%- set tile_X_list=[] %}
|
| 76 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 77 |
+
{%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
|
| 78 |
+
{%- endfor %}
|
| 79 |
+
for (int64_t nci = nc; nci < nc_block_end; nci++) {
|
| 80 |
+
{%- set tile_W_3d_list=[] %}
|
| 81 |
+
{%- set tile_W_list=[] %}
|
| 82 |
+
{%- set acc_slice_list=[] %}
|
| 83 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 84 |
+
{%- set acc_slice_list = acc_slice_list.append(
|
| 85 |
+
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
|
| 86 |
+
) %}
|
| 87 |
+
{%- set tile_W_3d_list = tile_W_3d_list.append(
|
| 88 |
+
kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
|
| 89 |
+
) %}
|
| 90 |
+
{%- endfor %}
|
| 91 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 92 |
+
{%- set tile_W_list = tile_W_list.append(
|
| 93 |
+
kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
|
| 94 |
+
) %}
|
| 95 |
+
{%- endfor %}
|
| 96 |
+
if (kc == k_block_start) {
|
| 97 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 98 |
+
{{ micro_gemm.codegen_call(
|
| 99 |
+
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
|
| 100 |
+
)|indent(28, false) }}
|
| 101 |
+
{%- endfor %}
|
| 102 |
+
} else {
|
| 103 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 104 |
+
{{ micro_gemm.codegen_call(
|
| 105 |
+
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
|
| 106 |
+
)|indent(28, false) }}
|
| 107 |
+
{%- endfor %}
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
{
|
| 112 |
+
{%- set tile_acc_list = [] %}
|
| 113 |
+
{%- set tile_Y_list = [] %}
|
| 114 |
+
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
| 115 |
+
{%- set tile_acc_list = tile_acc_list.append(
|
| 116 |
+
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
|
| 117 |
+
) %}
|
| 118 |
+
{%- set tile_Y_list = tile_Y_list.append(
|
| 119 |
+
kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
|
| 120 |
+
) %}
|
| 121 |
+
{%- endfor %}
|
| 122 |
+
{{ kernel.store_outputs(
|
| 123 |
+
tile_Y_list,
|
| 124 |
+
tile_acc_list,
|
| 125 |
+
GemmOuts,
|
| 126 |
+
epilogue_nodes,
|
| 127 |
+
offsets=("m_start", "n_start"),
|
| 128 |
+
reindexers=reindexers,
|
| 129 |
+
multi_output_buffers=multi_output_buffers
|
| 130 |
+
)|indent(20, false)
|
| 131 |
+
}}
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
{{ micro_gemm.codegen_finalize(kernel) }}
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]:
|
| 142 |
+
act_deduplicated = []
|
| 143 |
+
act_deduplicated_name: OrderedSet[str] = OrderedSet()
|
| 144 |
+
for act_idx in range(len(act_mapping.values())):
|
| 145 |
+
act = act_mapping[act_idx]
|
| 146 |
+
if act.get_name() not in act_deduplicated_name:
|
| 147 |
+
act_deduplicated.append(act)
|
| 148 |
+
act_deduplicated_name.add(act.get_name())
|
| 149 |
+
return act_deduplicated
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class CppGroupedGemmTemplate(CppGemmTemplate):
    """C++ template that fuses a group of same-shaped GEMMs into one kernel.

    See ``__init__`` for the constraints on the grouped GEMMs. Choices are
    registered via :meth:`add_choices`; code is emitted by :meth:`render`.
    """

    def __init__(
        self,
        input_nodes: list[ir.IRNode],
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta: int = 1,
        alpha: int = 1,
        # NOTE(review): annotated bool, but add_choices passes a
        # tuple[bool, ...] (one flag per GEMM) and render() indexes it —
        # the annotation looks stale; confirm.
        has_bias: bool = False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
        gemm_grouped_num: int = 1,
    ) -> None:
        """
        Template for Group of GEMMs:
        * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
        for their A, B, and C matrices.
        * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
        * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
        This behavior can be extended in the future if needed.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta,
            alpha,
            has_bias,
            epilogue_creator,
        )
        # Mapping of gemm index -> its activation IR node (may share nodes).
        self.act_mapping = act_mapping
        self.gemm_grouped_num = gemm_grouped_num
        # One placeholder output buffer per GEMM in the group.
        self.output_node: list[ir.Buffer] = [
            ir.Buffer(name="buf_out" + str(idx), layout=layout)
            for idx in range(gemm_grouped_num)
        ]

    @classmethod
    def add_choices(
        cls,
        choices: list[ChoiceCaller],
        layout: ir.Layout,
        input_nodes: list[ir.IRNode],
        beta: int = 1,
        alpha: int = 1,
        has_bias: tuple[bool, ...] = (False, False),
        trans_w: bool = False,
        input_indices: Optional[list[int]] = None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,  # gemm idx to its act buf
    ) -> DataProcessorTemplateWrapper:
        """Register a grouped-GEMM template choice for autotuning.

        Builds pre-/post-processing closures (input reordering, dense
        conversion, weight transpose/bias expansion, VNNI weight packing)
        and wraps this template in a DataProcessorTemplateWrapper.
        """
        # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
        gemm_grouped_num = len(has_bias)
        assert act_mapping
        act_deduplicated = get_deduplicated_act(act_mapping)
        # Weights start after the deduplicated activations; biases after weights.
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + gemm_grouped_num
        # NOTE(review): the input_indices argument is unconditionally
        # overwritten with the identity permutation — confirm intended.
        input_indices = list(range(len(input_nodes)))

        _T = TypeVar("_T", ir.IRNode, torch.Tensor)
        _U = TypeVar("_U", ir.Layout, torch.Tensor)

        def reorder_and_filter(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Reorder inputs per input_indices (identity here; kept for symmetry
            # with other templates' preprocessing pipelines).
            assert input_indices is not None, "input_indices must be set"
            return [inputs[idx] for idx in input_indices], layout_or_out

        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)

        def maybe_to_dense(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Convert any MKLDNN-backed weight tensors to dense layout.
            new_inputs = list(inputs)
            for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                if isinstance(inputs[idx], torch.Tensor):
                    W = inputs[idx]
                    assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
                    new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # If the weights come transposed, transpose them back and expand
            # each bias to match the activation; no-op when trans_w is False.
            new_inputs: list[_T] = list(inputs)
            if not trans_w:
                return new_inputs, layout_or_out
            X = new_inputs[0]
            for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                new_input = new_inputs[wgt_idx]
                new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
            for bias_idx in range(bias_start_idx, len(new_inputs)):
                new_bias = expand_bias(new_inputs[bias_idx], X)
                assert new_bias is not None
                new_inputs[bias_idx] = new_bias
            return new_inputs, layout_or_out

        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        # Use the first activation and first weight to derive m/n/k and dtypes;
        # by construction all GEMMs in the group share these.
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
        # Pad N to a multiple of the micro-kernel's register block.
        new_size, padded_n = cls.get_padded_size(
            n, block_n, k, should_block_weight=True
        )
        padding = padded_n - n

        def pack_weight(
            inputs: list[_T],
            layout_or_out: _U,
        ) -> tuple[list[_T], _U]:
            # Block and VNNI-pack every weight in the group.
            new_W_list = []
            new_inputs = list(inputs)
            W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
            for W in W_list:
                blocked_w = cls.block_weight(W, new_size, padding)
                new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
            new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
            return new_inputs, layout_or_out

        def preprocessor(
            inputs: list[_T],
            layout: _U,
        ) -> tuple[list[_T], _U]:
            # Full input pipeline: reorder -> densify -> normalize -> pack.
            return pack_weight(
                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
            )

        def postprocessor(output: _T) -> _T:
            # After the template is chosen, replace the weight constants on the
            # template buffer with their packed counterparts and prune the
            # originals from the graph.
            if isinstance(output, ir.TensorBox):
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
                W_nodes = new_input_nodes[
                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
                ]
                W_tensor = []
                for W_node in W_nodes:
                    assert W_node.get_name() in V.graph.constants
                    W_tensor.append(V.graph.constants[W_node.get_name()])
                new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
                    W_tensor  # type: ignore[assignment]
                )
                new_input_nodes, _ = pack_weight(
                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                )
                # Prune unused tensors
                prune_tensors(input_nodes, new_input_nodes)
                for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                    W_packed = new_input_nodes[idx]
                    assert isinstance(W_packed, torch.Tensor)
                    W_packed_constant = V.graph.add_tensor_constant(W_packed)
                    template_buffer.inputs[idx] = (
                        ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
                    )
            return output

        template = DataProcessorTemplateWrapper(
            CppGroupedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            act_mapping=act_mapping,
            gemm_grouped_num=gemm_grouped_num,
        )
        template.maybe_append_choice(choices)
        return template

    def render(  # type: ignore[override,return,no-untyped-def]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        # NOTE(review): flag_template_buffer_has_other_users is accepted for
        # signature compatibility but unused in this override — confirm.
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[list[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """Render the grouped-GEMM C++ source from GEMM_TEMPLATE.

        Gathers per-GEMM activations/weights/biases, instantiates the micro
        kernel, builds bias-add and user epilogues, then renders the Jinja
        template with the assembled options.
        """
        assert self.act_mapping
        act_deduplicated = get_deduplicated_act(self.act_mapping)
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + self.gemm_grouped_num
        X_list = list(self.act_mapping.values())
        W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
        # Collect the optional bias node for each GEMM; biases are packed
        # contiguously after the weights, only for GEMMs with has_bias set.
        inp_list = []
        cur_idx = bias_start_idx
        for inp_idx in range(self.gemm_grouped_num):
            inp = None
            if self.has_bias[inp_idx]:
                inp = self.input_nodes[cur_idx]
                cur_idx += 1
            inp_list.append(inp)

        Y_list = self.output_node
        multi_output_buffers = None
        if template_buffer_node is not None:
            # Rendering the chosen template: take the (packed) weights and
            # real multi-output buffers from the scheduled buffer node.
            W_list = template_buffer_node.inputs[
                wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
            ]
            assert isinstance(template_buffer_node.outputs, list)
            Y_list = template_buffer_node.outputs
            counters["inductor"]["cpp_grouped_gemm_template"] += 1
            multi_output_buffers = template_buffer_node.outputs

        template_buffer = Y_list[0]
        fake_buffers: list[ir.Buffer] = []
        Y_2d_list = Y_list
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            X_list[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            f"{kernel.kernel_name}_micro_gemm",
            self.m,
            self.n,
            self.k,
            input_dtype=X_list[0].get_dtype(),
            input2_dtype=W_list[0].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
        )
        assert micro_gemm is not None
        assert self.register_blocking == micro_gemm.register_blocking
        self.log_blockings()
        if isinstance(micro_gemm, CppMicroGemmAMX):
            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1

        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"

        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"

        epilogues: list[ir.IRNode] = []
        reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
        # One intermediate accumulation buffer per GEMM in the group.
        gemm_output_buffers: list[ir.Buffer] = []
        for out_buf_idx in range(self.gemm_grouped_num):
            gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
                out_buf_idx
            )
            gemm_output_buffers.append(
                ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
            )

        assert not self.epilogue_creator, (
            "epilogue_creator is not supported yet in Grouped GEMM Template"
        )

        # Kernel arguments, keyed by the names the Jinja template expects:
        # X<i> for deduplicated activations, W<i> for weights, inp<i> for biases.
        kernel_args: dict[str, Optional[ir.IRNode]] = {}
        for x_idx in range(wgt_start_idx):
            kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
        for w_idx in range(self.gemm_grouped_num):
            kernel_args["W" + str(w_idx)] = W_list[w_idx]
        for inp_idx in range(self.gemm_grouped_num):
            kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

        def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise:
            # Pointwise epilogue computing buf + beta * inp.
            return create_epilogue_with_attr(
                buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
            )

        # Attach a bias-add epilogue for every GEMM that carries a bias.
        for gemm_idx, inp in enumerate(inp_list):
            if inp:
                buffer_name = Y_list[gemm_idx].get_name()
                epilogues.append(
                    ir.ComputedBuffer(
                        name=buffer_name,
                        layout=template_buffer.layout,
                        data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp),
                    )
                )
                reindexers.append(None)

        # User epilogues run after the bias-adds; re-view each onto the
        # 2D template output and collect their reindexers.
        if epilogue_nodes:
            epilogues.extend(epilogue_nodes)
            for epilogue_node in epilogue_nodes:
                Y = cast(ir.Buffer, epilogue_node)
                _, reindexers = gen_2d_view_of_epilogue_buf(
                    Y,
                    template_buffer,
                    [
                        epilogue_node,
                    ],
                    reindexers,
                    default_reindexers=[
                        None,
                    ],
                )

        options = dict(
            N=self.n,
            K=self.k,
            PADDED_N=self.padded_n,
            aliases={},
            beta=self.beta,
            alpha=self.alpha,
            num_threads=self.num_threads,
            micro_gemm=micro_gemm,
            is_dynamic_M=self.is_dynamic_M,
            template=self,
            kernel=kernel,
            export_declaration=get_export_declaration(),
            acc_buf_dtype=torch.float,
            DTYPE_TO_CPP=DTYPE_TO_CPP,
            L1_cache_size=L1_cache_size,
            L2_cache_size=L2_cache_size,
            config=config,
            epilogue_nodes=epilogues,
            GemmOuts=gemm_output_buffers,
            reindexers=reindexers,
            kernel_args=kernel_args,
            X_list=X_list,
            W_list=W_list,
            gemm_grouped_num=self.gemm_grouped_num,
            Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
            Y_2d_list=Y_2d_list,
            multi_output_buffers=multi_output_buffers,
        )
        # Patch V.graph.get_dtype so dtype queries against the fake buffers
        # created during rendering resolve correctly.
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
            )
            return self._template_from_string(GEMM_TEMPLATE).render(**options)
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py
ADDED
|
@@ -0,0 +1,2011 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import dataclasses
|
| 3 |
+
import operator
|
| 4 |
+
import sys
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Callable, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .. import cpp_builder, ir
|
| 11 |
+
from ..cpu_vec_isa import (
|
| 12 |
+
pick_vec_isa,
|
| 13 |
+
VecAMX,
|
| 14 |
+
VecAVX2,
|
| 15 |
+
VecAVX512,
|
| 16 |
+
VecISA,
|
| 17 |
+
VecNEON,
|
| 18 |
+
VecSVE256,
|
| 19 |
+
)
|
| 20 |
+
from ..utils import IndentedBuffer, parallel_num_threads
|
| 21 |
+
from ..virtualized import V
|
| 22 |
+
from .common import KernelTemplate
|
| 23 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 24 |
+
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LayoutType(Enum):
    """Weight (B matrix) memory-layout tags used by the micro-GEMM codegen."""

    NORMAL = 0  # plain, un-packed layout
    VNNI2 = 1  # presumably 2-row VNNI interleaved packing (fp16/bf16 path) — confirm
    VNNI4 = 2  # presumably 4-row VNNI interleaved packing (int8 path) — confirm
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_IS_WINDOWS = sys.platform == "win32"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_restrict_keyword() -> str:
    """Return the compiler-specific spelling of the C/C++ ``restrict`` extension.

    MSVC spells it ``__restrict``
    (https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170);
    GCC/Clang use ``__restrict__``.
    """
    return "__restrict" if sys.platform == "win32" else "__restrict__"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CppMicroGemm:
|
| 45 |
+
"""
|
| 46 |
+
A class that codegens a kernel that computes small-sized matrix multiplication.
|
| 47 |
+
|
| 48 |
+
A micro GEMM kernel is responsible for register blocking, instruction selection,
|
| 49 |
+
and other CPU architecture-specific optimizations.
|
| 50 |
+
|
| 51 |
+
The subclasses need to override `codegen_define` to define the kernel function
|
| 52 |
+
that is called by the code generated by `codegen_call`.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
    # TODO(jgong5): support constant shapes and lds as template args.
    # Jinja source for the generated micro-kernel's C++ declaration; rendered
    # by get_kernel_declaration() with the options from get_common_options().
    DECLARE_KERNEL = r"""
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
    {{kernel_extra_args_declare}}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
)
"""
|
| 73 |
+
|
| 74 |
+
    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ) -> None:
        """
        Args:
            name: C++ symbol name used for the generated kernel function.
            input_dtype: dtype of the A (activation) matrix.
            input2_dtype: dtype of the B (weight) matrix; must not be None.
            output_dtype: dtype of the C (output) matrix.
            compute_dtype: dtype used for accumulation.
            register_blocking: register-level blocking for the micro kernel.
            alpha: scalar multiplier applied to A @ B.
        """
        self.name = name
        self.input_dtype = input_dtype
        # B's dtype must be provided explicitly; there is no implicit default.
        assert input2_dtype is not None
        self.input2_dtype = input2_dtype
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha
        # When True, B is VNNI-packed inside the kernel instead of ahead of
        # time; toggled via use_local_vnni_blocking().
        self.pack_vnni_B_locally = False
|
| 93 |
+
|
| 94 |
+
def get_common_options(self):
|
| 95 |
+
if self.input_dtype in [torch.uint8, torch.int8]:
|
| 96 |
+
assert self.compute_dtype == torch.int32
|
| 97 |
+
assert self.output_dtype == torch.int32
|
| 98 |
+
assert self.input2_dtype == torch.int8
|
| 99 |
+
return {
|
| 100 |
+
"torch": torch,
|
| 101 |
+
"kernel_name": self.name,
|
| 102 |
+
"input_dtype": self.input_dtype,
|
| 103 |
+
"input2_dtype": self.input2_dtype,
|
| 104 |
+
"output_dtype": self.output_dtype,
|
| 105 |
+
"compute_dtype": self.compute_dtype,
|
| 106 |
+
"input_t": DTYPE_TO_CPP[self.input_dtype],
|
| 107 |
+
"input2_t": DTYPE_TO_CPP[self.input2_dtype],
|
| 108 |
+
"output_t": DTYPE_TO_CPP[self.output_dtype],
|
| 109 |
+
"compute_t": DTYPE_TO_CPP[self.compute_dtype],
|
| 110 |
+
"alpha": self.alpha,
|
| 111 |
+
"kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
|
| 112 |
+
"int8_gemm": self.input_dtype in [torch.uint8, torch.int8],
|
| 113 |
+
"vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2,
|
| 114 |
+
"restrict_keyword": get_restrict_keyword(),
|
| 115 |
+
"pack_vnni_B_locally": self.pack_vnni_B_locally,
|
| 116 |
+
"template": self,
|
| 117 |
+
"is_woq_int4": self.is_woq_int4(),
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
def get_kernel_declaration(self):
|
| 121 |
+
options = self.get_common_options()
|
| 122 |
+
return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)
|
| 123 |
+
|
| 124 |
+
    def get_kernel_extra_args_declare(self) -> str:
        """Extra parameter declarations injected into the kernel signature.

        The base implementation contributes none; intended as an override
        point (DECLARE_KERNEL emits this only when non-empty).
        """
        return ""
|
| 126 |
+
|
| 127 |
+
    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Extra argument lines prepended to the kernel call in codegen_call().

        The base implementation contributes none; intended as an override point.
        """
        return []
|
| 129 |
+
|
| 130 |
+
    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Generate the kernel function definition; subclasses must override."""
        raise NotImplementedError
|
| 132 |
+
|
| 133 |
+
def codegen_call(
|
| 134 |
+
self,
|
| 135 |
+
kernel: CppTemplateKernel,
|
| 136 |
+
A: ir.Buffer,
|
| 137 |
+
B: ir.Buffer,
|
| 138 |
+
C: ir.Buffer,
|
| 139 |
+
accum: bool,
|
| 140 |
+
prefetch: bool = False,
|
| 141 |
+
**kwargs_for_extra_args,
|
| 142 |
+
) -> str:
|
| 143 |
+
"""
|
| 144 |
+
Generate the code for calling the templated kernel that computes
|
| 145 |
+
`C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
|
| 146 |
+
"""
|
| 147 |
+
A_ptr = f"&({kernel.index(A, [0, 0])})"
|
| 148 |
+
B_ptr = f"&({kernel.index(B, [0, 0])})"
|
| 149 |
+
C_ptr = f"&({kernel.index(C, [0, 0])})"
|
| 150 |
+
M = kernel.size(C, 0)
|
| 151 |
+
N = kernel.size(C, 1)
|
| 152 |
+
K = kernel.size(A, 1)
|
| 153 |
+
lda = kernel.stride(A, 0)
|
| 154 |
+
ldb = kernel.stride(B, 0)
|
| 155 |
+
ldc = kernel.stride(C, 0)
|
| 156 |
+
res = IndentedBuffer()
|
| 157 |
+
res.writeline(
|
| 158 |
+
f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>("
|
| 159 |
+
)
|
| 160 |
+
with res.indent():
|
| 161 |
+
kwargs_for_extra_args.update({"kernel": kernel})
|
| 162 |
+
extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
|
| 163 |
+
for arg in extra_args:
|
| 164 |
+
res.writeline(arg)
|
| 165 |
+
res.writeline(f"{A_ptr},")
|
| 166 |
+
res.writeline(f"{B_ptr},")
|
| 167 |
+
res.writeline(f"{C_ptr},")
|
| 168 |
+
res.writeline(f"{M},")
|
| 169 |
+
res.writeline(f"{N},")
|
| 170 |
+
res.writeline(f"{K},")
|
| 171 |
+
res.writeline(f"{lda},")
|
| 172 |
+
res.writeline(f"{ldb},")
|
| 173 |
+
res.writeline(f"{ldc}")
|
| 174 |
+
res.writeline(");")
|
| 175 |
+
return res.getvalue()
|
| 176 |
+
|
| 177 |
+
def use_local_vnni_blocking(self, should_block_weight: bool):
|
| 178 |
+
self.pack_vnni_B_locally = should_block_weight
|
| 179 |
+
|
| 180 |
+
def codegen_init(
|
| 181 |
+
self,
|
| 182 |
+
kernel: CppTemplateKernel,
|
| 183 |
+
) -> str:
|
| 184 |
+
return ""
|
| 185 |
+
|
| 186 |
+
def codegen_finalize(
|
| 187 |
+
self,
|
| 188 |
+
kernel: CppTemplateKernel,
|
| 189 |
+
) -> str:
|
| 190 |
+
return ""
|
| 191 |
+
|
| 192 |
+
def get_b_layout(self) -> LayoutType:
|
| 193 |
+
return LayoutType.NORMAL
|
| 194 |
+
|
| 195 |
+
ALLOCATE_WEIGHT_BUFFER = r"""
|
| 196 |
+
{%- if is_msvc_compiler %}
|
| 197 |
+
// MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here.
|
| 198 |
+
std::unique_ptr<{{buffer_dtype}}[]> heap_deq_b_buf_ptr(new {{buffer_dtype}}[{{buffer_size}}]);
|
| 199 |
+
{{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get();
|
| 200 |
+
{%- else %}
|
| 201 |
+
// It's safe to use a stack-allocated array since the blocking strategy would
|
| 202 |
+
// require us to allocate an array that's smaller than the size of L1D cache,
|
| 203 |
+
// and the default per thread max stack size on Linux is quite higher,
|
| 204 |
+
// so we need not worry about stack overflow.
|
| 205 |
+
alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}];
|
| 206 |
+
{%- endif %}
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def codegen_allocate_weight_buffer(
|
| 210 |
+
self, buffer_name: str, buffer_dtype: str, *size_args
|
| 211 |
+
) -> str:
|
| 212 |
+
buffer_size = " * ".join(map(str, size_args))
|
| 213 |
+
return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
|
| 214 |
+
dict(
|
| 215 |
+
buffer_name=buffer_name,
|
| 216 |
+
buffer_dtype=buffer_dtype,
|
| 217 |
+
buffer_size=buffer_size,
|
| 218 |
+
is_msvc_compiler=cpp_builder.is_msvc_cl(),
|
| 219 |
+
)
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def is_woq_int4(self):
|
| 223 |
+
return False
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
@dataclasses.dataclass
|
| 227 |
+
class CppMicroGemmConfig:
|
| 228 |
+
input_dtype: torch.dtype
|
| 229 |
+
input2_dtype: torch.dtype
|
| 230 |
+
output_dtype: torch.dtype
|
| 231 |
+
compute_dtype: torch.dtype
|
| 232 |
+
vec_isa_cls: type[VecISA]
|
| 233 |
+
register_blocking: GemmBlocking
|
| 234 |
+
extra_check: Optional[Callable[..., bool]] = None
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def register_micro_gemm(*configs):
|
| 241 |
+
def inner(cls):
|
| 242 |
+
assert cls not in micro_gemm_configs, (
|
| 243 |
+
f"Duplicate micro_gemm registration for {cls}"
|
| 244 |
+
)
|
| 245 |
+
assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
|
| 246 |
+
micro_gemm_configs[cls] = list(configs)
|
| 247 |
+
return cls
|
| 248 |
+
|
| 249 |
+
return inner
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def generate_gemm_config(
|
| 253 |
+
vec_isa_cls,
|
| 254 |
+
register_blockings,
|
| 255 |
+
input_dtype=torch.float,
|
| 256 |
+
input2_dtype=None,
|
| 257 |
+
output_dtype=None,
|
| 258 |
+
compute_dtype=None,
|
| 259 |
+
extra_check=None,
|
| 260 |
+
):
|
| 261 |
+
if output_dtype is None:
|
| 262 |
+
output_dtype = input_dtype
|
| 263 |
+
if compute_dtype is None:
|
| 264 |
+
compute_dtype = output_dtype
|
| 265 |
+
if input2_dtype is None:
|
| 266 |
+
input2_dtype = input_dtype
|
| 267 |
+
return [
|
| 268 |
+
CppMicroGemmConfig(
|
| 269 |
+
input_dtype,
|
| 270 |
+
input2_dtype,
|
| 271 |
+
output_dtype,
|
| 272 |
+
compute_dtype,
|
| 273 |
+
vec_isa_cls,
|
| 274 |
+
GemmBlocking(*blocking),
|
| 275 |
+
extra_check,
|
| 276 |
+
)
|
| 277 |
+
for blocking in register_blockings
|
| 278 |
+
]
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class CppMicroGemmRef(CppMicroGemm):
|
| 282 |
+
"""
|
| 283 |
+
A reference implementation of the CppMicroGemm class with naive C++ code.
|
| 284 |
+
It is used for correctness debugging.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
TEMPLATE_ENTRY = r"""
|
| 288 |
+
{{declare_kernel}} {
|
| 289 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 290 |
+
for (int64_t n = 0; n < N; ++n) {
|
| 291 |
+
{{compute_t}} result = accum ? C[m * ldc + n] : 0;
|
| 292 |
+
for (int64_t k = 0; k < K; ++k) {
|
| 293 |
+
result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
|
| 294 |
+
}
|
| 295 |
+
C[m * ldc + n] = result;
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
|
| 303 |
+
) -> None:
|
| 304 |
+
super().__init__(
|
| 305 |
+
name,
|
| 306 |
+
input_dtype,
|
| 307 |
+
input2_dtype,
|
| 308 |
+
output_dtype,
|
| 309 |
+
compute_dtype,
|
| 310 |
+
GemmBlocking(1, 1, 1),
|
| 311 |
+
alpha,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
| 315 |
+
options = {
|
| 316 |
+
"declare_kernel": self.get_kernel_declaration(),
|
| 317 |
+
**self.get_common_options(),
|
| 318 |
+
}
|
| 319 |
+
return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
|
| 323 |
+
return (
|
| 324 |
+
k % config.register_blocking.block_k == 0
|
| 325 |
+
and n % config.register_blocking.block_n == 0
|
| 326 |
+
and m < 16
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# extra check for small M dimension for int8 WoQ case
|
| 331 |
+
def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
|
| 332 |
+
return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get(
|
| 333 |
+
"dynamic_M", False
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
|
| 338 |
+
def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
|
| 339 |
+
return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@register_micro_gemm(
|
| 343 |
+
*generate_gemm_config(
|
| 344 |
+
VecAVX512,
|
| 345 |
+
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
| 346 |
+
input_dtype=torch.float,
|
| 347 |
+
),
|
| 348 |
+
*generate_gemm_config(
|
| 349 |
+
VecAVX512,
|
| 350 |
+
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
| 351 |
+
input_dtype=torch.bfloat16,
|
| 352 |
+
output_dtype=torch.float,
|
| 353 |
+
),
|
| 354 |
+
*generate_gemm_config(
|
| 355 |
+
VecAVX512,
|
| 356 |
+
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
| 357 |
+
input_dtype=torch.half,
|
| 358 |
+
output_dtype=torch.float,
|
| 359 |
+
),
|
| 360 |
+
*generate_gemm_config(
|
| 361 |
+
VecAVX512,
|
| 362 |
+
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
| 363 |
+
input_dtype=torch.bfloat16,
|
| 364 |
+
input2_dtype=torch.int8,
|
| 365 |
+
output_dtype=torch.float,
|
| 366 |
+
compute_dtype=torch.float,
|
| 367 |
+
extra_check=do_not_use_with_small_m_for_int8_woq,
|
| 368 |
+
),
|
| 369 |
+
*generate_gemm_config(
|
| 370 |
+
VecAVX512,
|
| 371 |
+
[
|
| 372 |
+
(4, 32, 64),
|
| 373 |
+
(8, 32, 64),
|
| 374 |
+
],
|
| 375 |
+
input_dtype=torch.bfloat16,
|
| 376 |
+
input2_dtype=torch.int8,
|
| 377 |
+
output_dtype=torch.float,
|
| 378 |
+
compute_dtype=torch.float,
|
| 379 |
+
extra_check=check_int8_woq_small_m_dim,
|
| 380 |
+
),
|
| 381 |
+
*generate_gemm_config(
|
| 382 |
+
VecAVX2,
|
| 383 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 384 |
+
input_dtype=torch.float,
|
| 385 |
+
),
|
| 386 |
+
*generate_gemm_config(
|
| 387 |
+
VecAVX2,
|
| 388 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 389 |
+
input_dtype=torch.bfloat16,
|
| 390 |
+
output_dtype=torch.float,
|
| 391 |
+
),
|
| 392 |
+
*generate_gemm_config(
|
| 393 |
+
VecAVX2,
|
| 394 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 395 |
+
input_dtype=torch.half,
|
| 396 |
+
output_dtype=torch.float,
|
| 397 |
+
),
|
| 398 |
+
*generate_gemm_config(
|
| 399 |
+
VecAVX2,
|
| 400 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 401 |
+
input_dtype=torch.bfloat16,
|
| 402 |
+
input2_dtype=torch.int8,
|
| 403 |
+
output_dtype=torch.float,
|
| 404 |
+
compute_dtype=torch.float,
|
| 405 |
+
extra_check=do_not_use_with_small_m_for_int8_woq,
|
| 406 |
+
),
|
| 407 |
+
*generate_gemm_config(
|
| 408 |
+
VecAVX2,
|
| 409 |
+
[
|
| 410 |
+
(2, 16, 64),
|
| 411 |
+
(4, 16, 64),
|
| 412 |
+
],
|
| 413 |
+
input_dtype=torch.bfloat16,
|
| 414 |
+
input2_dtype=torch.int8,
|
| 415 |
+
output_dtype=torch.float,
|
| 416 |
+
compute_dtype=torch.float,
|
| 417 |
+
extra_check=check_int8_woq_small_m_dim,
|
| 418 |
+
),
|
| 419 |
+
*generate_gemm_config(
|
| 420 |
+
VecNEON,
|
| 421 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 422 |
+
input_dtype=torch.float,
|
| 423 |
+
input2_dtype=torch.float,
|
| 424 |
+
output_dtype=torch.float,
|
| 425 |
+
compute_dtype=torch.float,
|
| 426 |
+
),
|
| 427 |
+
*generate_gemm_config(
|
| 428 |
+
VecSVE256,
|
| 429 |
+
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
| 430 |
+
input_dtype=torch.float,
|
| 431 |
+
input2_dtype=torch.float,
|
| 432 |
+
output_dtype=torch.float,
|
| 433 |
+
compute_dtype=torch.float,
|
| 434 |
+
),
|
| 435 |
+
)
|
| 436 |
+
class CppMicroGemmFP32Vec(CppMicroGemm):
|
| 437 |
+
"""
|
| 438 |
+
This class generates the code for micro gemm using fp32 vec instructions for compute.
|
| 439 |
+
It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
|
| 440 |
+
The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
|
| 441 |
+
if the desired output is BF16/FP16.
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
TEMPLATE_ENTRY = r"""
|
| 445 |
+
{{declare_kernel}} {
|
| 446 |
+
using Vectorized = at::vec::Vectorized<{{compute_t}}>;
|
| 447 |
+
constexpr auto VLEN = Vectorized::size();
|
| 448 |
+
{{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
|
| 449 |
+
{{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
|
| 450 |
+
// TODO(jgong5): loop unroll for M and N
|
| 451 |
+
for (int64_t m = 0; m < M; m += {{block_m}}) {
|
| 452 |
+
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
|
| 453 |
+
for (int64_t n = 0; n < N; n += {{block_n}}) {
|
| 454 |
+
int64_t block_n = std::min<int64_t>(N - n, {{block_n}});
|
| 455 |
+
if (block_m == {{block_m}} && block_n == {{block_n}}) {
|
| 456 |
+
{%- if not trans_b %}
|
| 457 |
+
{{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
|
| 458 |
+
{%- else %}
|
| 459 |
+
{{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
|
| 460 |
+
{%- endif %}
|
| 461 |
+
A + m * lda,
|
| 462 |
+
{%- if not trans_b %}
|
| 463 |
+
B + n,
|
| 464 |
+
{%- else %}
|
| 465 |
+
B + n * ldb,
|
| 466 |
+
{%- endif %}
|
| 467 |
+
C + m * ldc + n,
|
| 468 |
+
K,
|
| 469 |
+
lda,
|
| 470 |
+
ldb,
|
| 471 |
+
ldc
|
| 472 |
+
);
|
| 473 |
+
{%- if tail_n %}
|
| 474 |
+
} else if (block_n == {{block_n}}){
|
| 475 |
+
{%- else %}
|
| 476 |
+
} else {
|
| 477 |
+
{%- endif %}
|
| 478 |
+
switch (block_m) {
|
| 479 |
+
{%- for b in range(block_m - 1, 0, -1) %}
|
| 480 |
+
case {{b}}:
|
| 481 |
+
{%- if not trans_b %}
|
| 482 |
+
{{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
|
| 483 |
+
{%- else %}
|
| 484 |
+
{{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
|
| 485 |
+
{%- endif %}
|
| 486 |
+
A + m * lda,
|
| 487 |
+
{%- if not trans_b %}
|
| 488 |
+
B + n,
|
| 489 |
+
{%- else %}
|
| 490 |
+
B + n * ldb,
|
| 491 |
+
{%- endif %}
|
| 492 |
+
C + m * ldc + n,
|
| 493 |
+
K,
|
| 494 |
+
lda,
|
| 495 |
+
ldb,
|
| 496 |
+
ldc
|
| 497 |
+
);
|
| 498 |
+
break;
|
| 499 |
+
{%- endfor %}
|
| 500 |
+
default:
|
| 501 |
+
{{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
{%- if tail_n %}
|
| 505 |
+
} else {
|
| 506 |
+
switch (block_m) {
|
| 507 |
+
{%- for b in range(block_m, 0, -1) %}
|
| 508 |
+
case {{b}}:
|
| 509 |
+
{%- if not trans_b %}
|
| 510 |
+
{{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
|
| 511 |
+
{%- else %}
|
| 512 |
+
{{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
|
| 513 |
+
{%- endif %}
|
| 514 |
+
A + m * lda,
|
| 515 |
+
{%- if not trans_b %}
|
| 516 |
+
B + n,
|
| 517 |
+
{%- else %}
|
| 518 |
+
B + n * ldb,
|
| 519 |
+
{%- endif %}
|
| 520 |
+
C + m * ldc + n,
|
| 521 |
+
block_n,
|
| 522 |
+
K,
|
| 523 |
+
lda,
|
| 524 |
+
ldb,
|
| 525 |
+
ldc
|
| 526 |
+
);
|
| 527 |
+
break;
|
| 528 |
+
{%- endfor %}
|
| 529 |
+
default:
|
| 530 |
+
{{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
|
| 531 |
+
}
|
| 532 |
+
}
|
| 533 |
+
{%- else %}
|
| 534 |
+
}
|
| 535 |
+
{%- endif %}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
"""
|
| 540 |
+
|
| 541 |
+
TEMPLATE_KERNEL = r"""
|
| 542 |
+
|
| 543 |
+
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum, bool prefetch=false>
|
| 544 |
+
{%- if not trans_b %}
|
| 545 |
+
{%- if tail_n %}
|
| 546 |
+
inline void {{kernel_name}}_ntail_kernel(
|
| 547 |
+
{%- else %}
|
| 548 |
+
inline void {{kernel_name}}_kernel(
|
| 549 |
+
{%- endif %}
|
| 550 |
+
{%- else %}
|
| 551 |
+
{%- if tail_n %}
|
| 552 |
+
inline void {{kernel_name}}_ntail_transpose_b_kernel(
|
| 553 |
+
{%- else %}
|
| 554 |
+
inline void {{kernel_name}}_transpose_b_kernel(
|
| 555 |
+
{%- endif %}
|
| 556 |
+
{%- endif %}
|
| 557 |
+
const {{input_t}}* {{restrict_keyword}} A,
|
| 558 |
+
const {{input2_t}}* {{restrict_keyword}} B,
|
| 559 |
+
{{output_t}}* {{restrict_keyword}} C,
|
| 560 |
+
{%- if tail_n %}
|
| 561 |
+
int64_t N,
|
| 562 |
+
{%- endif %}
|
| 563 |
+
int64_t K,
|
| 564 |
+
int64_t lda,
|
| 565 |
+
int64_t ldb,
|
| 566 |
+
int64_t ldc
|
| 567 |
+
) {
|
| 568 |
+
using Vectorized = at::vec::Vectorized<{{compute_t}}>;
|
| 569 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 570 |
+
using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
|
| 571 |
+
{%- endif %}
|
| 572 |
+
|
| 573 |
+
{%- if not trans_b %}
|
| 574 |
+
constexpr auto VLEN = Vectorized::size();
|
| 575 |
+
constexpr auto ROWS = BLOCK_M;
|
| 576 |
+
constexpr auto COLS = BLOCK_N / VLEN;
|
| 577 |
+
|
| 578 |
+
Vectorized va;
|
| 579 |
+
at::vec::VectorizedN<{{compute_t}}, COLS> vb;
|
| 580 |
+
at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
|
| 581 |
+
|
| 582 |
+
{%- if tail_n %}
|
| 583 |
+
int64_t rCOLS = (N + VLEN - 1) / VLEN;
|
| 584 |
+
int ntail = N % VLEN;
|
| 585 |
+
{%- endif %}
|
| 586 |
+
auto loadc = [&](auto i) {
|
| 587 |
+
if constexpr (accum) {
|
| 588 |
+
constexpr int row = i / COLS;
|
| 589 |
+
constexpr int col = i % COLS;
|
| 590 |
+
{%- if tail_n %}
|
| 591 |
+
int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
|
| 592 |
+
if (col < rCOLS) {
|
| 593 |
+
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size);
|
| 594 |
+
}
|
| 595 |
+
{%- else %}
|
| 596 |
+
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
|
| 597 |
+
{%- endif %}
|
| 598 |
+
} else {
|
| 599 |
+
vc[i] = Vectorized(0.0f);
|
| 600 |
+
}
|
| 601 |
+
};
|
| 602 |
+
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
|
| 603 |
+
|
| 604 |
+
auto compute = [&, COLS](auto i, int k) {
|
| 605 |
+
constexpr int row = i / COLS;
|
| 606 |
+
constexpr int col = i % COLS;
|
| 607 |
+
{%- if tail_n %}
|
| 608 |
+
int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
|
| 609 |
+
{%- endif %}
|
| 610 |
+
if constexpr (col == 0) {
|
| 611 |
+
{%- if alpha != 1 %}
|
| 612 |
+
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
|
| 613 |
+
{%- else %}
|
| 614 |
+
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
|
| 615 |
+
{%- endif %}
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
if constexpr (row == 0) {
|
| 619 |
+
{%- if tail_n %}
|
| 620 |
+
if (col < rCOLS) {
|
| 621 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 622 |
+
auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size);
|
| 623 |
+
vb[col] = at::vec::convert<{{compute_t}}>(b);
|
| 624 |
+
{%- elif input2_dtype == torch.int8 %}
|
| 625 |
+
// Convert VLEN int8 elements to int32, and then fp32
|
| 626 |
+
auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN, load_size);
|
| 627 |
+
vb[col] = at::vec::convert<float>(b32);
|
| 628 |
+
{%- else %}
|
| 629 |
+
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size);
|
| 630 |
+
{%- endif %}
|
| 631 |
+
} else {
|
| 632 |
+
vb[col] = Vectorized(0.0f);
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
{%- else %}
|
| 636 |
+
|
| 637 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 638 |
+
auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
|
| 639 |
+
vb[col] = at::vec::convert<{{compute_t}}>(b);
|
| 640 |
+
{%- elif input2_dtype == torch.int8 %}
|
| 641 |
+
// Convert VLEN int8 elements to int32, and then fp32
|
| 642 |
+
auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
|
| 643 |
+
if constexpr (prefetch) {
|
| 644 |
+
_mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
|
| 645 |
+
}
|
| 646 |
+
vb[col] = at::vec::convert<float>(b32);
|
| 647 |
+
{%- else %}
|
| 648 |
+
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
|
| 649 |
+
{%- endif %}
|
| 650 |
+
{%- endif %}
|
| 651 |
+
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
constexpr int idx = row * COLS + col;
|
| 655 |
+
{%- if tail_n %}
|
| 656 |
+
if (col < rCOLS) {
|
| 657 |
+
vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
|
| 658 |
+
}
|
| 659 |
+
{%- else %}
|
| 660 |
+
vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
|
| 661 |
+
{%- endif %}
|
| 662 |
+
};
|
| 663 |
+
|
| 664 |
+
for (int k = 0; k < K; ++k) {
|
| 665 |
+
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
// store to C
|
| 669 |
+
auto storec = [&](auto i) {
|
| 670 |
+
constexpr int row = i / COLS;
|
| 671 |
+
constexpr int col = i % COLS;
|
| 672 |
+
{%- if tail_n %}
|
| 673 |
+
int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
|
| 674 |
+
if (col < rCOLS) {
|
| 675 |
+
vc[i].store(C + row * ldc + col * VLEN, store_size);
|
| 676 |
+
}
|
| 677 |
+
{%- else %}
|
| 678 |
+
vc[i].store(C + row * ldc + col * VLEN);
|
| 679 |
+
{%- endif %}
|
| 680 |
+
};
|
| 681 |
+
c10::ForcedUnroll<ROWS * COLS>{}(storec);
|
| 682 |
+
|
| 683 |
+
{%- else %}
|
| 684 |
+
// Use 2 implementations for the transposed B:
|
| 685 |
+
// First implementation:
|
| 686 |
+
// Transpose first and then perform outer product calculation in sub-blocks,
|
| 687 |
+
// which introduces an additional transpose overhead of [K, N] compared to the non-transpose version.
|
| 688 |
+
// Second implementation:
|
| 689 |
+
// Directly perform inner product calculation in sub-blocks,
|
| 690 |
+
// which introduces an additional vector reduction of [M, N] compared to the non-tranpose version.
|
| 691 |
+
// Therefore, when M * N / (K * N) is large, the first implementation has better performance.
|
| 692 |
+
{%- if tail_n %}
|
| 693 |
+
if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
|
| 694 |
+
{%- else %}
|
| 695 |
+
if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
|
| 696 |
+
{%- endif %}
|
| 697 |
+
// First implementation:
|
| 698 |
+
constexpr auto VLEN = Vectorized::size();
|
| 699 |
+
constexpr auto ROWS = BLOCK_M;
|
| 700 |
+
constexpr auto COLS = BLOCK_N / VLEN;
|
| 701 |
+
int _K = K / VLEN;
|
| 702 |
+
Vectorized va;
|
| 703 |
+
at::vec::VectorizedN<{{compute_t}}, VLEN> vb;
|
| 704 |
+
at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
|
| 705 |
+
auto loadc = [&](auto i) {
|
| 706 |
+
if constexpr (accum) {
|
| 707 |
+
constexpr int row = i / COLS;
|
| 708 |
+
constexpr int col = i % COLS;
|
| 709 |
+
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
|
| 710 |
+
} else {
|
| 711 |
+
vc[i] = Vectorized(0.0f);
|
| 712 |
+
}
|
| 713 |
+
};
|
| 714 |
+
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
|
| 715 |
+
auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) {
|
| 716 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 717 |
+
auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN);
|
| 718 |
+
vb[i] = at::vec::convert<{{compute_t}}>(b);
|
| 719 |
+
{%- elif input2_dtype == torch.int8 %}
|
| 720 |
+
auto b32 = at::vec::convert_to_int32<int8_t>(src_ptr + i * ldb, VLEN);
|
| 721 |
+
vb[i] = at::vec::convert<float>(b32);
|
| 722 |
+
{%- else %}
|
| 723 |
+
vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN);
|
| 724 |
+
{%- endif %}
|
| 725 |
+
};
|
| 726 |
+
auto compute_trans = [&, COLS](auto i, int k) {
|
| 727 |
+
constexpr int row = i % ROWS;
|
| 728 |
+
constexpr int col = i / ROWS;
|
| 729 |
+
constexpr int e_col = col * VLEN;
|
| 730 |
+
int idk = k * VLEN;
|
| 731 |
+
if constexpr (row == 0) {
|
| 732 |
+
c10::ForcedUnroll<VLEN>{}(unroll_loadB, B + e_col * ldb + idk);
|
| 733 |
+
at::vec::transpose_block(vb);
|
| 734 |
+
}
|
| 735 |
+
constexpr int idx = row * COLS + col;
|
| 736 |
+
{{kernel.unroll_pragma(16)}}
|
| 737 |
+
for (int j = 0; j < VLEN; j++) {
|
| 738 |
+
{%- if alpha != 1 %}
|
| 739 |
+
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}});
|
| 740 |
+
{%- else %}
|
| 741 |
+
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]));
|
| 742 |
+
{%- endif %}
|
| 743 |
+
vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]);
|
| 744 |
+
}
|
| 745 |
+
};
|
| 746 |
+
for (int k = 0; k < _K; ++k) {
|
| 747 |
+
c10::ForcedUnroll<ROWS * COLS>{}(compute_trans, k);
|
| 748 |
+
}
|
| 749 |
+
// store to C
|
| 750 |
+
auto storec = [&](auto i) {
|
| 751 |
+
constexpr int row = i / COLS;
|
| 752 |
+
constexpr int col = i % COLS;
|
| 753 |
+
vc[i].store(C + row * ldc + col * VLEN);
|
| 754 |
+
};
|
| 755 |
+
c10::ForcedUnroll<ROWS * COLS>{}(storec);
|
| 756 |
+
} else {
|
| 757 |
+
// Second implementation
|
| 758 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 759 |
+
constexpr auto VLEN = VectorizedIn::size();
|
| 760 |
+
{%- else %}
|
| 761 |
+
constexpr auto VLEN = Vectorized::size();
|
| 762 |
+
{%- endif %}
|
| 763 |
+
int _K = (K + VLEN - 1) / VLEN;
|
| 764 |
+
// sub-block size of BLOCK_N and BLOCK_M
|
| 765 |
+
constexpr int sM = {{sub_block_m}};
|
| 766 |
+
constexpr int sN = {{sub_block_n}};
|
| 767 |
+
{%- if tail_n %}
|
| 768 |
+
int bN = (N + sN - 1) / sN;
|
| 769 |
+
{%- else %}
|
| 770 |
+
constexpr int bN = (BLOCK_N + sN - 1) / sN;
|
| 771 |
+
{%- endif %}
|
| 772 |
+
constexpr int bM = (BLOCK_M + sM - 1) / sM;
|
| 773 |
+
|
| 774 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 775 |
+
at::vec::VectorizedN<{{compute_t}}, 2> va;
|
| 776 |
+
at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb;
|
| 777 |
+
{%- else %}
|
| 778 |
+
at::vec::Vectorized<{{compute_t}}> va;
|
| 779 |
+
at::vec::VectorizedN<{{compute_t}}, sN> vb;
|
| 780 |
+
{%- endif %}
|
| 781 |
+
at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid;
|
| 782 |
+
|
| 783 |
+
{%- if tail_n %}
|
| 784 |
+
int ntail = N % sN;
|
| 785 |
+
{%- else %}
|
| 786 |
+
constexpr int ntail = BLOCK_N % sN;
|
| 787 |
+
{%- endif %}
|
| 788 |
+
constexpr int mtail = BLOCK_M % sM;
|
| 789 |
+
int ktail = K % VLEN;
|
| 790 |
+
|
| 791 |
+
auto compute_trans = [&](int m, int n, int k) {
|
| 792 |
+
{%- if tail_n %}
|
| 793 |
+
int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN;
|
| 794 |
+
{%- else %}
|
| 795 |
+
int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN;
|
| 796 |
+
{%- endif %}
|
| 797 |
+
int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM;
|
| 798 |
+
int e_k = (k == _K - 1 && ktail != 0) ? (K - k * VLEN) : VLEN;
|
| 799 |
+
{{kernel.unroll_pragma(sub_block_n)}}
|
| 800 |
+
for (int i = 0; i < e_n; i++) {
|
| 801 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 802 |
+
auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
|
| 803 |
+
std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b);
|
| 804 |
+
{%- elif input2_dtype == torch.int8 %}
|
| 805 |
+
auto b32 = at::vec::convert_to_int32<int8_t>(B + (sN * n + i) * ldb + k * VLEN, e_k);
|
| 806 |
+
vb[i] = at::vec::convert<float>(b32);
|
| 807 |
+
{%- else %}
|
| 808 |
+
vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
|
| 809 |
+
{%- endif %}
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
{{kernel.unroll_pragma(sub_block_m)}}
|
| 813 |
+
for (int s = 0; s < e_m; s++) {
|
| 814 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 815 |
+
auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
|
| 816 |
+
std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a);
|
| 817 |
+
{%- elif input2_dtype == torch.int8 %}
|
| 818 |
+
auto a32 = at::vec::convert_to_int32<int8_t>(A + (sM * m + s) * lda + k * VLEN, e_k);
|
| 819 |
+
va = at::vec::convert<float>(a32);
|
| 820 |
+
{%- else %}
|
| 821 |
+
va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
|
| 822 |
+
{%- endif %}
|
| 823 |
+
|
| 824 |
+
{%- if alpha != 1 %}
|
| 825 |
+
va = va * Vectorized({{alpha}});
|
| 826 |
+
{%- endif %}
|
| 827 |
+
if (k == 0) {
|
| 828 |
+
{{kernel.unroll_pragma(sub_block_n)}}
|
| 829 |
+
for (int i = 0; i < e_n; i++) {
|
| 830 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 831 |
+
vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f));
|
| 832 |
+
vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
|
| 833 |
+
{%- else %}
|
| 834 |
+
vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f));
|
| 835 |
+
{%- endif %}
|
| 836 |
+
}
|
| 837 |
+
} else {
|
| 838 |
+
{{kernel.unroll_pragma(sub_block_n)}}
|
| 839 |
+
for (int i = 0; i < e_n; i++) {
|
| 840 |
+
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
|
| 841 |
+
vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]);
|
| 842 |
+
vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
|
| 843 |
+
{%- else %}
|
| 844 |
+
vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]);
|
| 845 |
+
{%- endif %}
|
| 846 |
+
}
|
| 847 |
+
}
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
// store to C
|
| 851 |
+
if (k == _K - 1) {
|
| 852 |
+
{{kernel.unroll_pragma(sub_block_m)}}
|
| 853 |
+
for (int s = 0; s < e_m; s++) {
|
| 854 |
+
{{kernel.unroll_pragma(sub_block_n)}}
|
| 855 |
+
for (int i = 0; i < e_n; i++) {
|
| 856 |
+
auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]);
|
| 857 |
+
if constexpr (accum) {
|
| 858 |
+
auto c = *(C + (sM * m + s) * ldc + sN * n + i);
|
| 859 |
+
*(C + (sM * m + s) * ldc + sN * n + i) = c + v;
|
| 860 |
+
} else {
|
| 861 |
+
*(C + (sM * m + s) * ldc + sN * n + i) = v;
|
| 862 |
+
}
|
| 863 |
+
}
|
| 864 |
+
}
|
| 865 |
+
}
|
| 866 |
+
};
|
| 867 |
+
|
| 868 |
+
for (int n = 0; n < bN; ++n) {
|
| 869 |
+
for (int m = 0; m < bM; ++m) {
|
| 870 |
+
for (int k = 0; k < _K; ++k) {
|
| 871 |
+
compute_trans(m, n, k);
|
| 872 |
+
}
|
| 873 |
+
}
|
| 874 |
+
}
|
| 875 |
+
}
|
| 876 |
+
{%- endif %}
|
| 877 |
+
}
|
| 878 |
+
"""
|
| 879 |
+
|
| 880 |
+
# set trans_b to generate gemm that supports transposed B matrix
|
| 881 |
+
# set tail_n to support the tail of N
|
| 882 |
+
# TODO add trans_b support for other micro gemms
|
| 883 |
+
# and move setting of trans_b to the init of CppMicroGemm
|
| 884 |
+
def __init__(
|
| 885 |
+
self,
|
| 886 |
+
name,
|
| 887 |
+
input_dtype,
|
| 888 |
+
input2_dtype,
|
| 889 |
+
output_dtype,
|
| 890 |
+
compute_dtype,
|
| 891 |
+
register_blocking,
|
| 892 |
+
alpha=1,
|
| 893 |
+
tail_n=False,
|
| 894 |
+
trans_b=False,
|
| 895 |
+
) -> None:
|
| 896 |
+
super().__init__(
|
| 897 |
+
name,
|
| 898 |
+
input_dtype,
|
| 899 |
+
input2_dtype,
|
| 900 |
+
output_dtype,
|
| 901 |
+
compute_dtype,
|
| 902 |
+
register_blocking,
|
| 903 |
+
alpha,
|
| 904 |
+
)
|
| 905 |
+
self.tail_n = tail_n
|
| 906 |
+
# trans_b is only supported on platforms that
|
| 907 |
+
# support avx512 or avx2 since transpose_block is
|
| 908 |
+
# only implemented on these platforms
|
| 909 |
+
if trans_b:
|
| 910 |
+
vec_isa = pick_vec_isa()
|
| 911 |
+
assert issubclass(vec_isa.__class__, VecAVX512) or issubclass(
|
| 912 |
+
vec_isa.__class__, VecAVX2
|
| 913 |
+
)
|
| 914 |
+
self.trans_b = trans_b
|
| 915 |
+
|
| 916 |
+
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
| 917 |
+
options = {
|
| 918 |
+
"declare_kernel": self.get_kernel_declaration(),
|
| 919 |
+
"kernel": kernel,
|
| 920 |
+
"block_m": self.register_blocking.block_m,
|
| 921 |
+
"block_n": self.register_blocking.block_n,
|
| 922 |
+
"block_k": self.register_blocking.block_k,
|
| 923 |
+
"trans_b": False,
|
| 924 |
+
"tail_n": False,
|
| 925 |
+
"restrict_keyword": get_restrict_keyword(),
|
| 926 |
+
**self.get_common_options(),
|
| 927 |
+
}
|
| 928 |
+
if self.trans_b:
|
| 929 |
+
# TODO supports tuning of sub_block_m/sub_block_n
|
| 930 |
+
# to get better performance for specific shapes
|
| 931 |
+
sub_block_m = min(1, self.register_blocking.block_m)
|
| 932 |
+
sub_block_n = min(4, self.register_blocking.block_n)
|
| 933 |
+
# update options to generate kernel with trans_b and sub-block size
|
| 934 |
+
options.update(
|
| 935 |
+
{
|
| 936 |
+
"trans_b": self.trans_b,
|
| 937 |
+
"sub_block_m": sub_block_m,
|
| 938 |
+
"sub_block_n": sub_block_n,
|
| 939 |
+
}
|
| 940 |
+
)
|
| 941 |
+
result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
|
| 942 |
+
options
|
| 943 |
+
)
|
| 944 |
+
# update options to generate the kernel for the tail of N
|
| 945 |
+
if self.tail_n:
|
| 946 |
+
options.update(
|
| 947 |
+
{
|
| 948 |
+
"tail_n": self.tail_n,
|
| 949 |
+
}
|
| 950 |
+
)
|
| 951 |
+
result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
|
| 952 |
+
options
|
| 953 |
+
)
|
| 954 |
+
result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
|
| 955 |
+
options
|
| 956 |
+
)
|
| 957 |
+
return result
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
# extra check for CppMicroGemmAMX
|
| 961 |
+
def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
|
| 962 |
+
vnni_size = 4 if config.input_dtype in [torch.uint8, torch.int8] else 2
|
| 963 |
+
return k % vnni_size == 0 and alpha == 1
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
# Register AMX configs as (block_m, block_n, block_k) triples per dtype combo;
# check_amx_extra gates on K's VNNI alignment and alpha == 1.
@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.int8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.uint8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
)
class CppMicroGemmAMX(CppMicroGemm):
    """
    This class generates the code for micro gemm using Advanced Matrix extension (AMX)
    instructions available in 4th generation Intel Xeon for compute.
    It supports input types of torch.bfloat16 with fp32 output.

    TEMPLATE_KERNEL is instantiated once per row-count variant (block_m,
    block_m - 16, ..., 16) by codegen_define; TEMPLATE_ENTRY dispatches over
    the M tail at runtime and handles optional local B dequantization/packing.
    """

    # Entry function: iterates N in block_n steps and M in block_m steps,
    # then chains the pre-generated per-row-count AMX kernels for the M tail.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
{%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}}
{%- endif %}
{%- if use_cached_dequantized_B %}
    // Create a stack-allocated buffer for tiles of B.
    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
    // we cache K * {{block_n}} elements of dequantized B
    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
    const auto buf_size = K * {{block_n}};
    auto load_dequantized_B = [&](int base_idx) {
        // Load a tile of B & cache it in L1D.
        {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
        for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
{%- for vec_idx in range(0, block_n, 32) %}
{%- if (block_n - vec_idx) >= 32 %}
            auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized<int8_t>::loadu(
                base_addr + idx_q + {{vec_idx}} ,
                static_cast<int64_t>(32)
            );
            auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}});
            b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}});
{%- else %}
            auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
                base_addr + idx_q + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
            auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail);
            b_bf16_tail.store(
                dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
{%- endif %}
{%- endfor %}
        }
    };
{%- endif %}
    // The ldb would not be block_n if N != block_n
{%- if use_cached_dequantized_B or pack_vnni_B_locally %}
    const int64_t updated_ldb = {{block_n}};
{%- else %}
    const int64_t updated_ldb = ldb;
{%- endif %}
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += {{block_n}}) {
{%- if pack_vnni_B_locally %}
        // Pack non-constant weights into VNNI interleaved format in packed_B_buf
        at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}});
{%- elif use_cached_dequantized_B %}
        // Dequantize K * block_n int8 B elements into BF16
        load_dequantized_B(n);
{%- endif %}
        for (int64_t m = 0; m < M; m += {{block_m}}) {
            int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
            int64_t m_tail = m;
{%- for num_rows in range(block_m, 0, -16) %}
{%- if num_rows != block_m %}
            else
{%- endif %}
            if (block_m >= {{num_rows}}) {
                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
                    amx_state,
                    A + m * lda,
{%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
{%- elif pack_vnni_B_locally %}
                    packed_B_buf,
{%- else %}
                    B + n,
{%- endif %}
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= {{num_rows}};
                m_tail += {{num_rows}};
            }
{%- endfor %}
            if (block_m > 0) {
                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
                    amx_state,
                    A + m_tail * lda,
{%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
{%- elif pack_vnni_B_locally %}
                    packed_B_buf,
{%- else %}
                    B + n,
{%- endif %}
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
"""

    # One AMX kernel body per (num_rows, num_columns) variant: configures the
    # tiles, accumulates over K in block_k steps, then handles the K tail with
    # a reconfigured (narrower) tile before the final store.
    TEMPLATE_KERNEL = r"""
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
    AMXState& amx_state,
    const {{input_t}}* {{restrict_keyword}} A,
{%- if use_cached_dequantized_B %}
    const {{input_t}}* {{restrict_keyword}} B,
{%- else %}
    const {{input2_t}}* {{restrict_keyword}} B,
{%- endif %}
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / {{block_k}} * {{block_k}};
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
    }
    auto load_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
{%- for tile_col in range(num_columns) %}
{%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
{%- endfor %}
{%- endfor %}
    };
    auto zero_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
{%- for tile_col in range(num_columns) %}
{%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_zero({{tile_idx}});
{%- endfor %}
{%- endfor %}
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
{%- set tile_offset_a = num_rows // 16 * num_columns %}
{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
{%- for tile_row in range(num_rows // 16) %}
{%- for tile_col in range(num_columns) %}
{%- set tile_idx_a = tile_offset_a + tile_row %}
{%- set tile_idx_b = tile_offset_b + tile_col %}
{%- set tile_idx_c = tile_row * num_columns + tile_col %}
{%- if tile_col == 0 %}
        _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
{%- endif %}
{%- if tile_row == 0 %}
        _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
{%- endif %}
{%- if int8_gemm %}
{%- if input_dtype == torch.int8 %}
        _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
{%- else %}
        _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
{%- endif %}
{%- else %}
        _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
{%- endif %}
{%- endfor %}
{%- endfor %}
    };

    {{kernel.unroll_pragma(4)}}
    for (int k = 0; k < last_k_offset; k += {{block_k}}) {
        compute(k);
    }

    auto store_c = [&]() {
        // store to C
{%- for tile_row in range(num_rows // 16) %}
{%- for tile_col in range(num_columns) %}
{%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
{%- endfor %}
{%- endfor %}
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render one TEMPLATE_KERNEL per row-count variant (block_m down to
        16 in steps of 16) plus the TEMPLATE_ENTRY dispatcher, and return the
        concatenated C++ source."""
        block_m, block_n, block_k = self.register_blocking
        assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
        assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
        # block_k is fixed by the AMX tile geometry: 64 bytes of K per tile row,
        # i.e. 64 int8 values or 32 bf16/fp16 values.
        if self.input_dtype in [torch.uint8, torch.int8]:
            assert block_k == 64, "Only support block_k = 64 for AMX INT8"
        else:
            assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
        num_columns = block_n // 16
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            # bf16 activations with int8 weights: dequantize B into a local
            # bf16 buffer before the AMX compute (see TEMPLATE_ENTRY).
            "use_cached_dequantized_B": (
                self.input_dtype == torch.bfloat16
                and self.input2_dtype in [torch.int8, torch.uint8]
            ),
            "kernel": kernel,
            "block_m": block_m,
            "block_n": block_n,
            "block_k": block_k,
            "num_columns": num_columns,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        for num_rows in range(block_m, 0, -16):
            amx_kernel_options = {**options, "num_rows": num_rows}
            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
                amx_kernel_options
            )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Declares the per-kernel AMX tile-configuration state object.
        return "AMXState amx_state;"

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Releases the AMX tile configuration once the kernel is done.
        return "amx_state.release([]() { _tile_release(); });"

    def get_kernel_extra_args_declare(self) -> str:
        # The generated entry function takes the AMX state by reference.
        return "AMXState& amx_state,"

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        # Matching call-site argument for get_kernel_extra_args_declare().
        return ["amx_state,"]

    def get_b_layout(self) -> LayoutType:
        """Required VNNI packing of the B (weight) operand: 4-wide for int8
        inputs, 2-wide otherwise."""
        if self.input_dtype in [torch.uint8, torch.int8]:
            return LayoutType.VNNI4
        else:
            return LayoutType.VNNI2
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
# extra check for CppMicroBrgemm
def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Gate the oneDNN brgemm micro-kernel: fp16 input with fp32 output,
    VNNI2-aligned K, alpha == 1, and amx_fp16 hardware support."""
    assert config.input_dtype == torch.half and config.output_dtype == torch.float
    # brgemm packs two fp16 values per dword (VNNI2), so K must be even.
    if alpha != 1 or k % 2 != 0:
        return False
    # use brgemm for Half when amx_fp16 is supported
    return torch.cpu._is_amx_fp16_supported()
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
# Registered only for fp16 input / fp32 output; check_brgemm_extra additionally
# requires amx_fp16 hardware support at runtime.
@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.half,
        output_dtype=torch.float,
        extra_check=check_brgemm_extra,
    ),
)
class CppMicroBrgemm(CppMicroGemm):
    """
    This class generates the code for micro gemm using oneDNN brgemm.
    It supports input types of torch.half.

    Unlike the hand-written AMX templates, the whole micro-kernel is a single
    call into at::native::cpublas::brgemm; only optional local VNNI packing of
    B is generated here.
    """

    TEMPLATE_ENTRY = r"""
#include <ATen/native/CPUBlas.h>
{{declare_kernel}} {
{%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}}
    at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N);
{%- endif %}
    at::native::cpublas::brgemm(
      M, N, K,
{%- if pack_vnni_B_locally %}
      lda, N, ldc,
{%- else %}
      lda, ldb, ldc,
{%- endif %}
      accum,
      A,
{%- if pack_vnni_B_locally %}
      packed_B_buf,
{%- else %}
      B,
{%- endif %}
      C);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render TEMPLATE_ENTRY (the brgemm call wrapper) and return the
        generated C++ source."""
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        # Releases resources held by the oneDNN brgemm backend.
        return "at::native::cpublas::brgemm_release();"

    def get_b_layout(self) -> LayoutType:
        """B operand must be VNNI2-packed; only valid for fp16 on amx_fp16
        hardware (enforced by check_brgemm_extra at registration time too)."""
        assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported()
        return LayoutType.VNNI2
|
| 1357 |
+
|
| 1358 |
+
|
| 1359 |
+
def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Applicability check for the WoQ int4 micro-gemms: alpha must be 1,
    the quantization group size must be valid for this blocking, and K/N
    must be divisible by block_k / 64 respectively."""
    if alpha != 1:
        return False
    q_group_size = kwargs.get("q_group_size", None)
    assert q_group_size is not None
    block_k = config.register_blocking.block_k
    # The quantization group must be at least 32 wide, divide K evenly,
    # and be no smaller than this config's K blocking.
    valid_group = (
        q_group_size >= 32
        and k % q_group_size == 0
        and block_k <= q_group_size
    )
    if not valid_group:
        return False
    return k % block_k == 0 and n % 64 == 0
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
@register_micro_gemm(
    # TODO: support float/half input
    *generate_gemm_config(
        VecAVX512,
        [(4, 64, 32), (4, 64, 64), (4, 64, 128)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.uint8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_woq_int4_extra,
    ),
)
class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec):
    """
    This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
    It is based on the corresponding ATen kernel.
    Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
    """

    # Entry function: tiles M/N and dispatches to the templated kernel; a
    # switch over the runtime M tail selects the matching BLOCK_M instantiation.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    auto group_size = q_group_size;
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            if (block_m == {{block_m}}) {
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
                    A + m * lda,
                    reinterpret_cast<const uint8_t*>(B) + n * ldb,
                    C + m * ldc + n,
                    K,
                    lda,
                    /* ldb */ {{block_n}} / 2,
                    ldc,
                    group_size,
                    ScaleAndZeros + n * 2,
                    lds,
                    k_start
                );
            } else {
                switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
                case {{b}}:
                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
                        A + m * lda,
                        reinterpret_cast<const uint8_t*>(B) + n * ldb,
                        C + m * ldc + n,
                        K,
                        lda,
                        /* ldb */ {{block_n}} / 2,
                        ldc,
                        group_size,
                        ScaleAndZeros + n * 2,
                        lds,
                        k_start
                    );
                    break;
{%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
"""

    # Templated kernel: dequantizes int4 B on the fly via a float LUT,
    # applies per-group scale/zero, and accumulates with AVX512 FMA.
    TEMPLATE_KERNEL = r"""
inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
    return (k_start + index) % group_size == 0;
}

inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) {
    __m128i tmp = _mm_loadu_si64((const __m128i*)data);
    __m128i bytes = _mm_cvtepu8_epi16(tmp);
    const __m128i lowMask = _mm_set1_epi8(0xF);
    __m128i high = _mm_andnot_si128(lowMask, bytes);
    __m128i low = _mm_and_si128(lowMask, bytes);
    high = _mm_slli_epi16(high, 4);
    bytes = _mm_or_si128(low, high);
    return bytes;
}

template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
    const {{input_t}}* {{restrict_keyword}} A,
    const uint8_t* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t q_group_size,
    const at::BFloat16* {{restrict_keyword}} ScaleAndZeros,
    int64_t lds, // leading dimension of ScaleAndZeros
    int64_t k_start) {
    constexpr int BLOCK_K = {{block_k}};
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;

    const int PREFETCH_SIZE_K = 16 * 4;
    const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

    // number of blocks on K
    const int KB = K / BLOCK_K;

    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 scale[COLS];
    __m512 zero[COLS];

    // Lookup table to de-quantize int4 values to bf16.
    // Values are dequantized as truly int4 [-8, 7] range;
    //
    // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
    //
    static const __m512 lut = _mm512_set_ps(
        7.0f, 6.0f, 5.0f, 4.0f,
        3.0f, 2.0f, 1.0f, 0.0f,
        -1.0f, -2.0f, -3.0f, -4.0f,
        -5.0f, -6.0f, -7.0f, -8.0f);

    // index for transpose
    static const __m512i idx1 = _mm512_set_epi32(
        30, 28, 26, 24, 22, 20, 18, 16,
        14, 12, 10, 8, 6, 4, 2, 0);
    static const __m512i idx2 = _mm512_set_epi32(
        31, 29, 27, 25, 23, 21, 19, 17,
        15, 13, 11, 9, 7, 5, 3, 1);

    // load scale and zero point
    auto load_scale_and_zeros = [&](int i, int _kb) {
        // load 2x bfloat16 vector
        __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
        if (_kb + PREFETCH_SIZE_KB < KB) {
            _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
        }

        // convert to 2x f32 vector
        __m512 a, b;
        at::vec::cvtbf16_fp32(t, a, b);

        // transpose scale_and_zero from {16, 2} to {2, 16}
        // inputs:
        //   a: {s0, z0, s1, z1, ..., s7, z7}
        //   b: {s8, z8, s9, z9, ..., s15, z15}
        // output:
        //   scale: {s0, s1, s2, ..., s15}
        //   zero: {z0, z1, z2, ..., z15}
        scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
        zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
    };

    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
        } else {
            vc[i] = _mm512_setzero_ps();
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);

    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;

        if constexpr (col == 0) {
            float aa = static_cast<float>(A[row * lda + k]);
            if (k + PREFETCH_SIZE_K < K) {
                _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
            }
            va = _mm512_set1_ps(aa);
        }

        if constexpr (row == 0) {
            if constexpr (COLS == 4) {
                // when BLOCK_N = 64, handle each row at a time
                // to reduce de-quantize overhead.
                if constexpr (col == 0) {
                    __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
                    if (k + PREFETCH_SIZE_K < K) {
                        _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
                    }

                    __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
                    vb[0] = _mm512_permutexvar_ps(b32, lut);
                    vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
                    vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
                    vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);

                    b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
                    vb[1] = _mm512_permutexvar_ps(b32, lut);
                    vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
                    vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
                    vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
                }
            } else {
                __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8);
                __m512i b32 = _mm512_cvtepu8_epi32(b8);
                vb[col] = _mm512_permutexvar_ps(b32, lut);
                vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
            }
        }

        constexpr int idx = row * COLS + col;
        vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
    };

    for (int k = 0, kb = 0; k < K; ++k) {
        if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) {
            c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
        }
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }

    //store to C
    auto storec = [&, COLS](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""

    def get_kernel_extra_args_declare(self) -> str:
        """Extra parameters the generated kernel signature needs for WoQ int4:
        group size, packed scales/zeros pointer, its leading dimension, and
        the starting K offset (for group alignment)."""
        return (
            "const int64_t q_group_size,\n"
            "    const at::BFloat16* __restrict__ ScaleAndZeros,\n"
            "    const int64_t lds,\n"
            "    int64_t k_start,"
        )

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Call-site arguments matching get_kernel_extra_args_declare();
        requires the template kernel and the scale/zeros buffer in kwargs."""
        assert "kernel" in kwargs
        assert "qscale_and_zeros" in kwargs
        kernel = kwargs["kernel"]
        qscale_and_zeros = kwargs["qscale_and_zeros"]
        return [
            "group_size,",
            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
            "N * 2,",  # lds
            "k_start,",
        ]

    def is_woq_int4(self) -> bool:
        # Marks this micro-gemm as weight-only-quantized int4.
        return True
|
| 1626 |
+
|
| 1627 |
+
|
| 1628 |
+
@register_micro_gemm(
|
| 1629 |
+
*generate_gemm_config(
|
| 1630 |
+
VecAMX,
|
| 1631 |
+
[ # (block_m, block_n, block_k)
|
| 1632 |
+
(16, 32, 32),
|
| 1633 |
+
(32, 32, 32),
|
| 1634 |
+
],
|
| 1635 |
+
input_dtype=torch.bfloat16,
|
| 1636 |
+
input2_dtype=torch.uint8,
|
| 1637 |
+
output_dtype=torch.float,
|
| 1638 |
+
compute_dtype=torch.float,
|
| 1639 |
+
extra_check=check_amx_extra,
|
| 1640 |
+
),
|
| 1641 |
+
)
|
| 1642 |
+
class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
|
| 1643 |
+
"""
|
| 1644 |
+
This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
|
| 1645 |
+
which are available on 4th and newer generations of Intel Xeon.
|
| 1646 |
+
Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
|
| 1647 |
+
Shape of packed ScalesAndZeros = [K // group_size, N, 2]
|
| 1648 |
+
Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.
|
| 1649 |
+
"""
|
| 1650 |
+
|
| 1651 |
+
TEMPLATE_ENTRY = r"""
|
| 1652 |
+
inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
|
| 1653 |
+
return (k_start + index) % group_size == 0;
|
| 1654 |
+
}
|
| 1655 |
+
|
| 1656 |
+
{{declare_kernel}} {
|
| 1657 |
+
{{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
|
| 1658 |
+
{{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
|
| 1659 |
+
{{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4");
|
| 1660 |
+
|
| 1661 |
+
// Create a stack-allocated buffer for tiles of B.
|
| 1662 |
+
// Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
|
| 1663 |
+
// we cache K * {{block_n}} elements of dequantized B
|
| 1664 |
+
{{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
|
| 1665 |
+
|
| 1666 |
+
constexpr int BLOCK_K = {{block_k}};
|
| 1667 |
+
constexpr int64_t BLOCK_N = {{block_n}};
|
| 1668 |
+
constexpr int COLS = BLOCK_N / 16;
|
| 1669 |
+
const int PREFETCH_SIZE_K = 16 * 4;
|
| 1670 |
+
const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
|
| 1671 |
+
const int KB = K / BLOCK_K;
|
| 1672 |
+
|
| 1673 |
+
__m512i b32[COLS * 2];
|
| 1674 |
+
__m512 vb[COLS * 2];
|
| 1675 |
+
__m512 scale[COLS];
|
| 1676 |
+
__m512 zero[COLS];
|
| 1677 |
+
|
| 1678 |
+
// Lookup table to de-quantize int4 values to bf16.
|
| 1679 |
+
// Values are dequantized as truly int4 [-8, 7] range;
|
| 1680 |
+
//
|
| 1681 |
+
// dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
|
| 1682 |
+
//
|
| 1683 |
+
static const __m512 lut = _mm512_set_ps(
|
| 1684 |
+
7.0f, 6.0f, 5.0f, 4.0f,
|
| 1685 |
+
3.0f, 2.0f, 1.0f, 0.0f,
|
| 1686 |
+
-1.0f, -2.0f, -3.0f, -4.0f,
|
| 1687 |
+
-5.0f, -6.0f, -7.0f, -8.0f);
|
| 1688 |
+
|
| 1689 |
+
// index for transpose
|
| 1690 |
+
static const __m512i idx1 = _mm512_set_epi32(
|
| 1691 |
+
30, 28, 26, 24, 22, 20, 18, 16,
|
| 1692 |
+
14, 12, 10, 8, 6, 4, 2, 0);
|
| 1693 |
+
static const __m512i idx2 = _mm512_set_epi32(
|
| 1694 |
+
31, 29, 27, 25, 23, 21, 19, 17,
|
| 1695 |
+
15, 13, 11, 9, 7, 5, 3, 1);
|
| 1696 |
+
|
| 1697 |
+
// Indices for VNNI layout conversion
|
| 1698 |
+
__m512i idx_low = _mm512_set_epi32(
|
| 1699 |
+
0x17,
|
| 1700 |
+
0x07,
|
| 1701 |
+
0x16,
|
| 1702 |
+
0x06,
|
| 1703 |
+
0x15,
|
| 1704 |
+
0x05,
|
| 1705 |
+
0x14,
|
| 1706 |
+
0x04,
|
| 1707 |
+
0x13,
|
| 1708 |
+
0x03,
|
| 1709 |
+
0x12,
|
| 1710 |
+
0x02,
|
| 1711 |
+
0x11,
|
| 1712 |
+
0x01,
|
| 1713 |
+
0x10,
|
| 1714 |
+
0x00);
|
| 1715 |
+
__m512i idx_high = _mm512_set_epi32(
|
| 1716 |
+
0x1f,
|
| 1717 |
+
0x0f,
|
| 1718 |
+
0x1e,
|
| 1719 |
+
0x0e,
|
| 1720 |
+
0x1d,
|
| 1721 |
+
0x0d,
|
| 1722 |
+
0x1c,
|
| 1723 |
+
0x0c,
|
| 1724 |
+
0x1b,
|
| 1725 |
+
0x0b,
|
| 1726 |
+
0x1a,
|
| 1727 |
+
0x0a,
|
| 1728 |
+
0x19,
|
| 1729 |
+
0x09,
|
| 1730 |
+
0x18,
|
| 1731 |
+
0x08);
|
| 1732 |
+
|
| 1733 |
+
// load scale and zero point
|
| 1734 |
+
auto load_scale_and_zeros = [&](int i, int _kb) {
|
| 1735 |
+
// load 2x bfloat16 vector
|
| 1736 |
+
__m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
|
| 1737 |
+
if (_kb + PREFETCH_SIZE_KB < KB) {
|
| 1738 |
+
_mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
|
| 1739 |
+
}
|
| 1740 |
+
|
| 1741 |
+
// convert to 2x f32 vector
|
| 1742 |
+
__m512 a, b;
|
| 1743 |
+
at::vec::cvtbf16_fp32(t, a, b);
|
| 1744 |
+
|
| 1745 |
+
// transpose scale_and_zero from {16, 2} to {2, 16}
|
| 1746 |
+
// inputs:
|
| 1747 |
+
// a: {s0, z0, s1, z1, ..., s7, z7}
|
| 1748 |
+
// b: {s8, z8, s9, z9, ..., s15, z15}
|
| 1749 |
+
// output:
|
| 1750 |
+
// scale: {s0, s1, s2, ..., s15}
|
| 1751 |
+
// zero: {z0, z1, z2, ..., z15}
|
| 1752 |
+
scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
|
| 1753 |
+
zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
|
| 1754 |
+
};
|
| 1755 |
+
|
| 1756 |
+
// Dequantize a B block of 2 * block_n into bf16
|
| 1757 |
+
// So, it handles k and k+1 at the same time
|
| 1758 |
+
auto dequantize_B = [&](int n) {
|
| 1759 |
+
constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
|
| 1760 |
+
for (int k = 0, kb = 0; k < K; k += 2) {
|
| 1761 |
+
// Since block_k must be 32 for AMX microkernels, k_start may not be
|
| 1762 |
+
// a multiple of q_group_size. In that case, we need to load scales
|
| 1763 |
+
// and zero points immediately when k == 0 here
|
| 1764 |
+
if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) {
|
| 1765 |
+
c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
|
| 1766 |
+
}
|
| 1767 |
+
|
| 1768 |
+
// load 256 bits = 64 elements in int4
|
| 1769 |
+
if (k + PREFETCH_SIZE_K < K) {
|
| 1770 |
+
_mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);
|
| 1771 |
+
}
|
| 1772 |
+
|
| 1773 |
+
__m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4));
|
| 1774 |
+
b32[0] = _mm512_cvtepu8_epi32(b4);
|
| 1775 |
+
b32[1] = _mm512_srli_epi32(b32[0], 4);
|
| 1776 |
+
vb[0] = _mm512_permutexvar_ps(b32[0] , lut);
|
| 1777 |
+
vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
|
| 1778 |
+
vb[1] = _mm512_permutexvar_ps(b32[1], lut);
|
| 1779 |
+
vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
|
| 1780 |
+
|
| 1781 |
+
b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4));
|
| 1782 |
+
b32[0 + COLS] = _mm512_cvtepu8_epi32(b4);
|
| 1783 |
+
b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
|
| 1784 |
+
vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut);
|
| 1785 |
+
vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
|
| 1786 |
+
vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
|
| 1787 |
+
vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);
|
| 1788 |
+
|
| 1789 |
+
for (int i = 0; i < COLS; i++) {
|
| 1790 |
+
// convert to VNNI
|
| 1791 |
+
auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]);
|
| 1792 |
+
auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]);
|
| 1793 |
+
// convert lower 16 float32 values to bfloat16
|
| 1794 |
+
auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low));
|
| 1795 |
+
// convert higher 16 float32 values to bfloat16
|
| 1796 |
+
auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high));
|
| 1797 |
+
// combine the lower 16 and higher 16 bfloat16 values
|
| 1798 |
+
auto v = _mm512_castsi256_si512(v0_bf16);
|
| 1799 |
+
v = _mm512_inserti64x4(v, v1_bf16, 1);
|
| 1800 |
+
// store the VNNI format bfloat16 values
|
| 1801 |
+
{{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
|
| 1802 |
+
_mm512_storeu_si512(addr, v);
|
| 1803 |
+
}
|
| 1804 |
+
}
|
| 1805 |
+
};
|
| 1806 |
+
|
| 1807 |
+
for (int64_t n = 0; n < N; n += {{block_n}}) {
|
| 1808 |
+
// Dequantize K * block_n int8 B elements into BF16
|
| 1809 |
+
dequantize_B(n);
|
| 1810 |
+
for (int64_t m = 0; m < M; m += {{block_m}}) {
|
| 1811 |
+
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
|
| 1812 |
+
int64_t m_tail = m;
|
| 1813 |
+
{%- for num_rows in range(block_m, 0, -16) %}
|
| 1814 |
+
{%- if num_rows != block_m %}
|
| 1815 |
+
else
|
| 1816 |
+
{%- endif %}
|
| 1817 |
+
if (block_m >= {{num_rows}}) {
|
| 1818 |
+
{{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
|
| 1819 |
+
amx_state,
|
| 1820 |
+
A + m * lda,
|
| 1821 |
+
dequantized_B_buf + n * K,
|
| 1822 |
+
C + m * ldc + n,
|
| 1823 |
+
K,
|
| 1824 |
+
lda,
|
| 1825 |
+
{{block_n}},
|
| 1826 |
+
ldc,
|
| 1827 |
+
16
|
| 1828 |
+
);
|
| 1829 |
+
block_m -= {{num_rows}};
|
| 1830 |
+
m_tail += {{num_rows}};
|
| 1831 |
+
}
|
| 1832 |
+
{%- endfor %}
|
| 1833 |
+
if (block_m > 0) {
|
| 1834 |
+
{{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
|
| 1835 |
+
amx_state,
|
| 1836 |
+
A + m_tail * lda,
|
| 1837 |
+
dequantized_B_buf + n * K,
|
| 1838 |
+
C + m_tail * ldc + n,
|
| 1839 |
+
K,
|
| 1840 |
+
lda,
|
| 1841 |
+
{{block_n}},
|
| 1842 |
+
ldc,
|
| 1843 |
+
block_m
|
| 1844 |
+
);
|
| 1845 |
+
}
|
| 1846 |
+
} // for m
|
| 1847 |
+
} // for n
|
| 1848 |
+
}
|
| 1849 |
+
"""
|
| 1850 |
+
|
| 1851 |
+
def get_kernel_extra_args_declare(self) -> str:
|
| 1852 |
+
return (
|
| 1853 |
+
"AMXState& amx_state,\n"
|
| 1854 |
+
" const int64_t q_group_size,\n"
|
| 1855 |
+
" const c10::BFloat16* __restrict__ ScaleAndZeros,\n"
|
| 1856 |
+
" const int64_t lds,\n"
|
| 1857 |
+
" int64_t k_start,"
|
| 1858 |
+
)
|
| 1859 |
+
|
| 1860 |
+
def get_kernel_extra_args(self, **kwargs) -> list[str]:
|
| 1861 |
+
assert "kernel" in kwargs
|
| 1862 |
+
assert "qscale_and_zeros" in kwargs
|
| 1863 |
+
kernel = kwargs["kernel"]
|
| 1864 |
+
qscale_and_zeros = kwargs["qscale_and_zeros"]
|
| 1865 |
+
return [
|
| 1866 |
+
"amx_state,",
|
| 1867 |
+
"group_size,",
|
| 1868 |
+
f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
|
| 1869 |
+
"N * 2,", # lds
|
| 1870 |
+
"k_start,",
|
| 1871 |
+
]
|
| 1872 |
+
|
| 1873 |
+
def is_woq_int4(self):
|
| 1874 |
+
return True
|
| 1875 |
+
|
| 1876 |
+
|
| 1877 |
+
def create_micro_gemm(
|
| 1878 |
+
name,
|
| 1879 |
+
m,
|
| 1880 |
+
n,
|
| 1881 |
+
k,
|
| 1882 |
+
input_dtype,
|
| 1883 |
+
input2_dtype,
|
| 1884 |
+
output_dtype=None,
|
| 1885 |
+
compute_dtype=None,
|
| 1886 |
+
alpha=1,
|
| 1887 |
+
num_threads=-1,
|
| 1888 |
+
use_ref=True,
|
| 1889 |
+
q_group_size=None,
|
| 1890 |
+
) -> Optional[CppMicroGemm]:
|
| 1891 |
+
"""
|
| 1892 |
+
Based on the provided info, try to find the config of the micro-kernel that would
|
| 1893 |
+
deliver the best performance in terms of lower latency for this case.
|
| 1894 |
+
"""
|
| 1895 |
+
|
| 1896 |
+
def create_from_config(cls, config: CppMicroGemmConfig):
|
| 1897 |
+
return cls(
|
| 1898 |
+
name,
|
| 1899 |
+
config.input_dtype,
|
| 1900 |
+
config.input2_dtype,
|
| 1901 |
+
config.output_dtype,
|
| 1902 |
+
config.compute_dtype,
|
| 1903 |
+
config.register_blocking,
|
| 1904 |
+
alpha,
|
| 1905 |
+
)
|
| 1906 |
+
|
| 1907 |
+
def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls):
|
| 1908 |
+
# For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
|
| 1909 |
+
# Exception: for dynamic shapes, we consider using the AMX micro-kernel.
|
| 1910 |
+
if (
|
| 1911 |
+
dynamic_M
|
| 1912 |
+
or input_dtype != torch.bfloat16
|
| 1913 |
+
or input2_dtype not in [torch.int8, torch.uint8]
|
| 1914 |
+
):
|
| 1915 |
+
return False
|
| 1916 |
+
# For WOQ INT8, use AMX for m >= block_m
|
| 1917 |
+
# For WOQ INT4, use AMX for m >= 5
|
| 1918 |
+
block_m, *_ = config.register_blocking
|
| 1919 |
+
is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx
|
| 1920 |
+
m_threshold = 5 if is_woq_int4 else block_m
|
| 1921 |
+
return m < m_threshold
|
| 1922 |
+
|
| 1923 |
+
assert isinstance(n, int) or n.is_number, n
|
| 1924 |
+
assert isinstance(k, int) or k.is_number, k
|
| 1925 |
+
from ..utils import has_free_symbols
|
| 1926 |
+
|
| 1927 |
+
dynamic_M = has_free_symbols((m,))
|
| 1928 |
+
m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
|
| 1929 |
+
assert isinstance(m, int) or m.is_number, m
|
| 1930 |
+
if output_dtype is None:
|
| 1931 |
+
output_dtype = input_dtype
|
| 1932 |
+
if compute_dtype is None:
|
| 1933 |
+
compute_dtype = output_dtype
|
| 1934 |
+
if num_threads < 0:
|
| 1935 |
+
num_threads = parallel_num_threads()
|
| 1936 |
+
vec_isa = pick_vec_isa()
|
| 1937 |
+
matched_configs = []
|
| 1938 |
+
for cls, configs in micro_gemm_configs.items():
|
| 1939 |
+
for config in configs:
|
| 1940 |
+
if not issubclass(vec_isa.__class__, config.vec_isa_cls):
|
| 1941 |
+
continue
|
| 1942 |
+
if (
|
| 1943 |
+
config.input_dtype == input_dtype
|
| 1944 |
+
and config.compute_dtype == compute_dtype
|
| 1945 |
+
and config.input2_dtype == input2_dtype
|
| 1946 |
+
and config.output_dtype == output_dtype
|
| 1947 |
+
# The output_dtype here is the output dtype of the micro-kernel.
|
| 1948 |
+
# In some cases, the actual output dtype of the op for which the micro-kernel
|
| 1949 |
+
# is being created would be same as that of the activation, but the micro-kernels
|
| 1950 |
+
# compute output in Float/int32, which is converted in the GEMM template. This is
|
| 1951 |
+
# subject to change in the future.
|
| 1952 |
+
):
|
| 1953 |
+
if config.extra_check is not None and not config.extra_check(
|
| 1954 |
+
config,
|
| 1955 |
+
m,
|
| 1956 |
+
n,
|
| 1957 |
+
k,
|
| 1958 |
+
alpha,
|
| 1959 |
+
num_threads,
|
| 1960 |
+
dynamic_M=dynamic_M,
|
| 1961 |
+
q_group_size=q_group_size,
|
| 1962 |
+
):
|
| 1963 |
+
continue
|
| 1964 |
+
block_m, block_n, block_k = config.register_blocking
|
| 1965 |
+
if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(
|
| 1966 |
+
config, dynamic_M, cls
|
| 1967 |
+
):
|
| 1968 |
+
continue
|
| 1969 |
+
# Criteria on the ranking of configurations
|
| 1970 |
+
# 1. ISA: AMX > VEC
|
| 1971 |
+
# 2. Dividable by block sizes (block_m, block_n, block_k)
|
| 1972 |
+
# 3. Number of mxn blocks is large enough to occupy all the threads
|
| 1973 |
+
# 4. Register blocks are larger
|
| 1974 |
+
isa_score = 0
|
| 1975 |
+
if config.vec_isa_cls == VecAMX:
|
| 1976 |
+
isa_score += 1
|
| 1977 |
+
dividable_score = 0
|
| 1978 |
+
if m % block_m == 0:
|
| 1979 |
+
dividable_score += 1
|
| 1980 |
+
if n % block_n == 0:
|
| 1981 |
+
dividable_score += 1
|
| 1982 |
+
if k % block_k == 0:
|
| 1983 |
+
dividable_score += 1
|
| 1984 |
+
occupancy_score = 0
|
| 1985 |
+
n_blocks = (n + block_n - 1) // block_n
|
| 1986 |
+
total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
|
| 1987 |
+
if n_blocks >= num_threads:
|
| 1988 |
+
occupancy_score += 1
|
| 1989 |
+
if total_mxn_blocks >= num_threads:
|
| 1990 |
+
occupancy_score += 1
|
| 1991 |
+
register_bytes = (
|
| 1992 |
+
block_m * block_n * config.compute_dtype.itemsize
|
| 1993 |
+
+ (block_m * block_k + block_k * block_n)
|
| 1994 |
+
* config.input_dtype.itemsize
|
| 1995 |
+
)
|
| 1996 |
+
matched_configs.append(
|
| 1997 |
+
(
|
| 1998 |
+
(isa_score, dividable_score, occupancy_score, register_bytes),
|
| 1999 |
+
cls,
|
| 2000 |
+
config,
|
| 2001 |
+
)
|
| 2002 |
+
)
|
| 2003 |
+
if len(matched_configs) == 0:
|
| 2004 |
+
if use_ref:
|
| 2005 |
+
return CppMicroGemmRef(
|
| 2006 |
+
name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
|
| 2007 |
+
)
|
| 2008 |
+
else:
|
| 2009 |
+
return None
|
| 2010 |
+
# TODO(jgong5): allow autotuning on choices of configs
|
| 2011 |
+
return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:])
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import ctypes
|
| 3 |
+
import functools
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from collections.abc import Iterable
|
| 8 |
+
from typing import Callable, Optional, Union
|
| 9 |
+
from unittest.mock import patch
|
| 10 |
+
|
| 11 |
+
import sympy
|
| 12 |
+
|
| 13 |
+
from .. import config, ir
|
| 14 |
+
from ..autotune_process import CppBenchmarkRequest, TensorMeta
|
| 15 |
+
from ..utils import IndentedBuffer, Placeholder, unique
|
| 16 |
+
from ..virtualized import V
|
| 17 |
+
from .common import KernelTemplate
|
| 18 |
+
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
log = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CppTemplate(KernelTemplate):
|
| 25 |
+
index_counter = itertools.count()
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
name: str,
|
| 30 |
+
input_nodes,
|
| 31 |
+
layout: ir.Layout,
|
| 32 |
+
num_threads: int,
|
| 33 |
+
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
| 34 |
+
) -> None:
|
| 35 |
+
super().__init__(name)
|
| 36 |
+
self.input_nodes = input_nodes
|
| 37 |
+
self.index = next(self.index_counter)
|
| 38 |
+
self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer(
|
| 39 |
+
name=f"buf_out{self.index}", layout=layout
|
| 40 |
+
)
|
| 41 |
+
self.layout = layout
|
| 42 |
+
self.num_threads = num_threads
|
| 43 |
+
self.epilogue_creator = epilogue_creator
|
| 44 |
+
|
| 45 |
+
def generate(self, **kwargs):
|
| 46 |
+
kernel_name = f"cpp_{self.name}"
|
| 47 |
+
with (
|
| 48 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
|
| 49 |
+
patch.object(ir.FlexibleLayout, "allow_indexing", True),
|
| 50 |
+
V.graph.set_current_device(self.layout.device),
|
| 51 |
+
CppTemplateKernel(
|
| 52 |
+
kernel_name=kernel_name, num_threads=self.num_threads
|
| 53 |
+
) as kernel,
|
| 54 |
+
):
|
| 55 |
+
code = kernel.render(self, **kwargs)
|
| 56 |
+
_, call_args, _, _ = kernel.args.python_argdefs()
|
| 57 |
+
log.debug("Generated Code:\n%s", code)
|
| 58 |
+
log.debug(
|
| 59 |
+
"Args: cpp_argdefs: %s, python_argdefs: %s",
|
| 60 |
+
kernel.args.cpp_argdefs(),
|
| 61 |
+
kernel.args.python_argdefs(),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
expected_args = list(
|
| 65 |
+
unique(input_node.get_name() for input_node in self.input_nodes)
|
| 66 |
+
)
|
| 67 |
+
if isinstance(self.output_node, Iterable):
|
| 68 |
+
expected_args.extend([node.get_name() for node in self.output_node])
|
| 69 |
+
else:
|
| 70 |
+
expected_args.extend([self.output_node.get_name()])
|
| 71 |
+
assert list(call_args)[: len(expected_args)] == expected_args, (
|
| 72 |
+
call_args,
|
| 73 |
+
expected_args,
|
| 74 |
+
)
|
| 75 |
+
extra_args = V.graph.sizevars.size_hints(
|
| 76 |
+
map(sympy.expand, call_args[len(expected_args) :])
|
| 77 |
+
)
|
| 78 |
+
# Cast the size hint from int to ctypes.c_ulonglong explicitly
|
| 79 |
+
# since in cpp kernel, we bind it to C long
|
| 80 |
+
extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args)
|
| 81 |
+
|
| 82 |
+
kernel_hash_name = f"cpp_{self.name}_{self.index}"
|
| 83 |
+
|
| 84 |
+
# Create the BenchmarkRequest for CPP
|
| 85 |
+
bmreq = CppBenchmarkRequest(
|
| 86 |
+
kernel_name=kernel_name,
|
| 87 |
+
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
|
| 88 |
+
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
| 89 |
+
extra_args=extra_args,
|
| 90 |
+
source_code=code,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def make_kernel_render(
|
| 94 |
+
template_node: ir.CppTemplateBuffer,
|
| 95 |
+
flag_template_buffer_has_other_users: bool,
|
| 96 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 97 |
+
):
|
| 98 |
+
kernel = CppTemplateKernel(
|
| 99 |
+
kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads
|
| 100 |
+
)
|
| 101 |
+
render = functools.partial(
|
| 102 |
+
kernel.render,
|
| 103 |
+
self,
|
| 104 |
+
template_buffer_node=template_node,
|
| 105 |
+
flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
|
| 106 |
+
epilogue_nodes=epilogue_nodes,
|
| 107 |
+
**kwargs,
|
| 108 |
+
)
|
| 109 |
+
return kernel, render
|
| 110 |
+
|
| 111 |
+
return CppTemplateCaller(
|
| 112 |
+
kernel_hash_name,
|
| 113 |
+
self.name,
|
| 114 |
+
self.input_nodes,
|
| 115 |
+
self.output_node[0].get_layout()
|
| 116 |
+
if isinstance(self.output_node, Iterable)
|
| 117 |
+
else self.output_node.get_layout(),
|
| 118 |
+
make_kernel_render,
|
| 119 |
+
bmreq,
|
| 120 |
+
self,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def header(self) -> IndentedBuffer:
|
| 124 |
+
res = IndentedBuffer()
|
| 125 |
+
res.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")
|
| 126 |
+
# TODO: add c10::ForcedUnroll test to test_aoti_abi_check
|
| 127 |
+
res.splice("""#include <c10/util/Unroll.h>""")
|
| 128 |
+
res.splice("""#include <torch/csrc/inductor/aoti_torch/c/shim.h>""")
|
| 129 |
+
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
|
| 130 |
+
"linux",
|
| 131 |
+
"win32",
|
| 132 |
+
]
|
| 133 |
+
if enable_kernel_profile:
|
| 134 |
+
res.writelines(["#include <ATen/record_function.h>"])
|
| 135 |
+
return res
|
| 136 |
+
|
| 137 |
+
def render(self, **kwargs) -> str:
|
| 138 |
+
raise NotImplementedError
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import itertools
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from typing import Any, Callable, Optional, Union
|
| 5 |
+
|
| 6 |
+
import sympy
|
| 7 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch._inductor.utils import do_bench_using_profiling
|
| 11 |
+
from torch.utils._ordered_set import OrderedSet
|
| 12 |
+
from torch.utils._sympy.symbol import SymT
|
| 13 |
+
|
| 14 |
+
from .. import config, cpp_builder, ir, lowering as L
|
| 15 |
+
from ..autotune_process import CppBenchmarkRequest
|
| 16 |
+
from ..loop_body import LoopBody
|
| 17 |
+
from ..select_algorithm import PartialRender
|
| 18 |
+
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
|
| 19 |
+
from ..virtualized import V
|
| 20 |
+
from .common import REMOVED
|
| 21 |
+
from .cpp import CppKernel, CppKernelProxy, KernelGroup
|
| 22 |
+
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_expr_with_index_symbols(expr):
|
| 26 |
+
if isinstance(expr, sympy.Expr):
|
| 27 |
+
return expr
|
| 28 |
+
elif isinstance(expr, (list, tuple)):
|
| 29 |
+
return [parse_expr_with_index_symbols(e) for e in expr]
|
| 30 |
+
else:
|
| 31 |
+
expr = parse_expr(str(expr))
|
| 32 |
+
int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols}
|
| 33 |
+
return expr.subs(int_symbols)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def wrap_with_tensorbox(node) -> ir.TensorBox:
|
| 37 |
+
return (
|
| 38 |
+
ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CppTemplateKernel(CppKernel):
|
| 43 |
+
def __init__(self, kernel_name, num_threads):
|
| 44 |
+
super().__init__(None, num_threads)
|
| 45 |
+
self.kernel_name = kernel_name
|
| 46 |
+
self.render_hooks = {}
|
| 47 |
+
self.local_buffers = {}
|
| 48 |
+
|
| 49 |
+
def render(self, template, **kwargs):
|
| 50 |
+
return PartialRender(
|
| 51 |
+
template.render(kernel=self, **kwargs), self.render_hooks
|
| 52 |
+
).finalize_all()
|
| 53 |
+
|
| 54 |
+
def def_kernel(
|
| 55 |
+
self,
|
| 56 |
+
inputs: dict[str, ir.Buffer],
|
| 57 |
+
outputs: dict[str, ir.Buffer],
|
| 58 |
+
aliases: Optional[dict[str, str]] = None,
|
| 59 |
+
function_name: str = "",
|
| 60 |
+
extra_sizevars: Optional[list[sympy.Expr]] = None,
|
| 61 |
+
placeholder: str = "<DEF_KERNEL>",
|
| 62 |
+
) -> str:
|
| 63 |
+
if len(function_name) == 0:
|
| 64 |
+
function_name = str(self.kernel_name)
|
| 65 |
+
for name, inp in inputs.items():
|
| 66 |
+
if inp is not None:
|
| 67 |
+
self.args.input_buffers[inp.get_name()] = name
|
| 68 |
+
for name, out in outputs.items():
|
| 69 |
+
self.args.output_buffers[out.get_name()] = name
|
| 70 |
+
if aliases is not None:
|
| 71 |
+
for alias, orig in aliases.items():
|
| 72 |
+
if orig in self.args.input_buffers:
|
| 73 |
+
self.args.input_buffers[alias] = self.args.input_buffers[orig]
|
| 74 |
+
if orig in self.args.output_buffers:
|
| 75 |
+
self.args.output_buffers[alias] = self.args.output_buffers[orig]
|
| 76 |
+
|
| 77 |
+
unique_sizevars = OrderedSet(
|
| 78 |
+
s
|
| 79 |
+
for input in inputs.values()
|
| 80 |
+
if input is not None
|
| 81 |
+
for sym in itertools.chain(input.get_size(), input.get_stride())
|
| 82 |
+
if isinstance(sym, sympy.Expr)
|
| 83 |
+
for s in sym.free_symbols
|
| 84 |
+
)
|
| 85 |
+
unique_sizevars.update(
|
| 86 |
+
s
|
| 87 |
+
for sym in extra_sizevars or []
|
| 88 |
+
if isinstance(sym, sympy.Expr)
|
| 89 |
+
for s in sym.free_symbols
|
| 90 |
+
)
|
| 91 |
+
unique_sizevars.update(
|
| 92 |
+
s
|
| 93 |
+
for output in outputs.values()
|
| 94 |
+
for sym in itertools.chain(output.get_size(), output.get_stride())
|
| 95 |
+
if isinstance(sym, sympy.Expr)
|
| 96 |
+
for s in sym.free_symbols
|
| 97 |
+
)
|
| 98 |
+
sizevars = sorted(unique_sizevars, key=str)
|
| 99 |
+
for sizevar in sizevars:
|
| 100 |
+
self.args.sizevars[sizevar] = f"k{sizevar}"
|
| 101 |
+
|
| 102 |
+
def hook():
|
| 103 |
+
# remove all aliases before generate function definition
|
| 104 |
+
if aliases is not None:
|
| 105 |
+
for alias in aliases:
|
| 106 |
+
if alias in self.args.input_buffers:
|
| 107 |
+
raise AssertionError(
|
| 108 |
+
f"input_buffers cannot be removed: {alias}"
|
| 109 |
+
)
|
| 110 |
+
if alias in self.args.output_buffers:
|
| 111 |
+
self.args.output_buffers[alias] = REMOVED
|
| 112 |
+
cpp_argdefs, _, _ = self.args.cpp_argdefs()
|
| 113 |
+
return f"void {function_name}({', '.join(cpp_argdefs)})"
|
| 114 |
+
|
| 115 |
+
assert placeholder not in self.render_hooks
|
| 116 |
+
self.render_hooks[placeholder] = hook
|
| 117 |
+
return placeholder
|
| 118 |
+
|
| 119 |
+
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
|
| 120 |
+
wrapper = V.graph.wrapper_code
|
| 121 |
+
_, call_args, arg_types = self.args.cpp_argdefs()
|
| 122 |
+
wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types)
|
| 123 |
+
|
| 124 |
+
def dtype(self, node: ir.Buffer) -> str:
|
| 125 |
+
return DTYPE_TO_CPP[node.get_dtype()]
|
| 126 |
+
|
| 127 |
+
def acc_dtype(self, node: ir.Buffer) -> str:
|
| 128 |
+
if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]:
|
| 129 |
+
return "float"
|
| 130 |
+
else:
|
| 131 |
+
raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}")
|
| 132 |
+
|
| 133 |
+
def size(self, node: ir.Buffer, dim: int) -> str:
|
| 134 |
+
return cexpr_index(self.rename_indexing(node.get_size()[dim]))
|
| 135 |
+
|
| 136 |
+
def stride(self, node: ir.Buffer, dim: int) -> str:
|
| 137 |
+
return cexpr_index(self.rename_indexing(node.get_stride()[dim]))
|
| 138 |
+
|
| 139 |
+
def index(self, node: ir.Buffer, indices: list[Any]) -> str:
|
| 140 |
+
indexer = node.get_layout().as_fixed().make_indexer()
|
| 141 |
+
index = indexer(parse_expr_with_index_symbols(indices))
|
| 142 |
+
index = self.rename_indexing(index)
|
| 143 |
+
outer_name = node.get_name()
|
| 144 |
+
inner_name = (
|
| 145 |
+
outer_name
|
| 146 |
+
if outer_name in self.local_buffers
|
| 147 |
+
else self.args.input(node.get_name())
|
| 148 |
+
)
|
| 149 |
+
return f"{inner_name}[{cexpr_index(index)}]"
|
| 150 |
+
|
| 151 |
+
def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView:
|
| 152 |
+
"""
|
| 153 |
+
Slice the given node with a list of ranges (start and end) corresponding to its dims.
|
| 154 |
+
The dim is not sliced if the corresponding range is empty.
|
| 155 |
+
"""
|
| 156 |
+
assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
|
| 157 |
+
sliced = wrap_with_tensorbox(node)
|
| 158 |
+
for dim, _range in enumerate(ranges):
|
| 159 |
+
if len(_range) == 0:
|
| 160 |
+
continue
|
| 161 |
+
assert len(_range) == 2
|
| 162 |
+
start, end = parse_expr_with_index_symbols(_range)
|
| 163 |
+
sliced = L.slice_(sliced, dim, start, end, clamp=False)
|
| 164 |
+
assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
|
| 165 |
+
return sliced.data
|
| 166 |
+
|
| 167 |
+
def select(self, node, dim: int, idx: int) -> ir.ReinterpretView:
|
| 168 |
+
# We avoid using L.select here because we need clamp=False so the dim after slicing
|
| 169 |
+
# is 1 instead of a sympy expression of symbol - dim_size.
|
| 170 |
+
node = wrap_with_tensorbox(node)
|
| 171 |
+
idx = ir.View.handle_negative_index(idx, node.get_size()[dim])
|
| 172 |
+
sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim)
|
| 173 |
+
assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
|
| 174 |
+
return sliced.data
|
| 175 |
+
|
| 176 |
+
def view(self, node, sizes: list[Any]) -> ir.View:
|
| 177 |
+
node = wrap_with_tensorbox(node)
|
| 178 |
+
sizes = parse_expr_with_index_symbols(sizes)
|
| 179 |
+
return L.view(node, sizes).data
|
| 180 |
+
|
| 181 |
+
def permute(self, node, dims):
|
| 182 |
+
node = wrap_with_tensorbox(node)
|
| 183 |
+
permuted = L.permute(node, dims).data
|
| 184 |
+
assert isinstance(permuted, ir.ReinterpretView)
|
| 185 |
+
return permuted
|
| 186 |
+
|
| 187 |
+
def maybe_codegen_profile(self) -> str:
|
| 188 |
+
if config.cpp.enable_kernel_profile:
|
| 189 |
+
graph_id = V.graph.graph_id
|
| 190 |
+
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
|
| 191 |
+
return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
|
| 192 |
+
else:
|
| 193 |
+
return ""
|
| 194 |
+
|
| 195 |
+
def unroll_pragma(self, unroll):
|
| 196 |
+
if cpp_builder.is_gcc():
|
| 197 |
+
return f"#pragma GCC unroll {unroll}"
|
| 198 |
+
else:
|
| 199 |
+
return f"#pragma unroll {unroll}"
|
| 200 |
+
|
| 201 |
+
def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str:
|
| 202 |
+
"""Define kernel local buffer"""
|
| 203 |
+
sizes = parse_expr_with_index_symbols(sizes)
|
| 204 |
+
buf = ir.Buffer(
|
| 205 |
+
name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
|
| 206 |
+
)
|
| 207 |
+
self.local_buffers[name] = buf
|
| 208 |
+
ctype = f"{DTYPE_TO_CPP[dtype]}"
|
| 209 |
+
numel = f"{cexpr_index(buf.get_numel())}"
|
| 210 |
+
return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"
|
| 211 |
+
|
| 212 |
+
def define_stack_allocated_buffer(
|
| 213 |
+
self, name, sizes: list[Any], dtype=torch.float
|
| 214 |
+
) -> str:
|
| 215 |
+
"""Define stack-allocated buffer"""
|
| 216 |
+
sizes = parse_expr_with_index_symbols(sizes)
|
| 217 |
+
buf = ir.Buffer(
|
| 218 |
+
name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
|
| 219 |
+
)
|
| 220 |
+
self.local_buffers[name] = buf
|
| 221 |
+
ctype = f"{DTYPE_TO_CPP[dtype]}"
|
| 222 |
+
numel = f"{cexpr_index(buf.get_numel())}"
|
| 223 |
+
return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};"
|
| 224 |
+
|
| 225 |
+
def reinit_buffer_if_null(self, name):
|
| 226 |
+
"""Reinit the previously defined local buffer if it is null"""
|
| 227 |
+
assert name in self.local_buffers
|
| 228 |
+
buf = self.local_buffers[name]
|
| 229 |
+
ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}"
|
| 230 |
+
numel = f"{cexpr_index(buf.get_numel())}"
|
| 231 |
+
return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}"
|
| 232 |
+
|
| 233 |
+
def release_buffer(self, name):
|
| 234 |
+
"""Codegen the code to release the ownership of a local buffer to others"""
|
| 235 |
+
assert name in self.local_buffers
|
| 236 |
+
return f"_{name}.release()"
|
| 237 |
+
|
| 238 |
+
def store_pointwise_nodes(
|
| 239 |
+
self,
|
| 240 |
+
dst: ir.Buffer,
|
| 241 |
+
nodes: list[ir.IRNode],
|
| 242 |
+
offsets: Optional[list[sympy.Expr]] = None,
|
| 243 |
+
reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
|
| 244 |
+
) -> str:
|
| 245 |
+
var_sizes = (tuple(dst.get_size()), ())
|
| 246 |
+
var_ranges = {
|
| 247 |
+
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
|
| 248 |
+
for i, sz in enumerate(var_sizes[0])
|
| 249 |
+
}
|
| 250 |
+
if not offsets:
|
| 251 |
+
offsets = [sympy.S.Zero] * len(var_sizes[0])
|
| 252 |
+
if not reindexers:
|
| 253 |
+
reindexers = [None] * len(nodes)
|
| 254 |
+
assert len(offsets) == len(var_sizes[0])
|
| 255 |
+
output_index = dst.get_layout().make_indexer()([*var_ranges.keys()])
|
| 256 |
+
kernel_group = KernelGroup()
|
| 257 |
+
kernel_group.args = self.args
|
| 258 |
+
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
| 259 |
+
bodies = []
|
| 260 |
+
var_sizes_list = []
|
| 261 |
+
for i, node in enumerate(nodes):
|
| 262 |
+
output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name()
|
| 263 |
+
node = node.data if isinstance(node, ir.ComputedBuffer) else node
|
| 264 |
+
assert isinstance(node, ir.Pointwise), node
|
| 265 |
+
|
| 266 |
+
def fn(*args):
|
| 267 |
+
assert len(args) == 2
|
| 268 |
+
assert len(args[0]) == len(var_sizes[0])
|
| 269 |
+
assert len(args[1]) == 0
|
| 270 |
+
new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type]
|
| 271 |
+
if reindexers[i] is not None:
|
| 272 |
+
new_args = reindexers[i](new_args) # type: ignore[misc]
|
| 273 |
+
V.ops.store(
|
| 274 |
+
output_name,
|
| 275 |
+
output_index,
|
| 276 |
+
node.make_loader()(new_args).value,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
body = LoopBody(
|
| 280 |
+
fn,
|
| 281 |
+
(list(var_ranges.keys()), ()),
|
| 282 |
+
var_ranges,
|
| 283 |
+
list(var_ranges.keys()),
|
| 284 |
+
tuple(),
|
| 285 |
+
)
|
| 286 |
+
bodies.append(body)
|
| 287 |
+
var_sizes_list.append(var_sizes)
|
| 288 |
+
|
| 289 |
+
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
|
| 290 |
+
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
| 291 |
+
return kernel_group.loops_code.getvalue()
|
| 292 |
+
|
| 293 |
+
def store_grouped_gemm_pointwise_nodes(
|
| 294 |
+
self,
|
| 295 |
+
dst: tuple[ir.Buffer],
|
| 296 |
+
nodes: list[ir.IRNode],
|
| 297 |
+
offsets: list[sympy.Expr],
|
| 298 |
+
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
|
| 299 |
+
output_names: list[str],
|
| 300 |
+
) -> str:
|
| 301 |
+
ref_dst = dst[0]
|
| 302 |
+
var_sizes = (tuple(ref_dst.get_size()), ())
|
| 303 |
+
var_ranges = {
|
| 304 |
+
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
|
| 305 |
+
for i, sz in enumerate(var_sizes[0])
|
| 306 |
+
}
|
| 307 |
+
assert offsets, "offsets should be set outside"
|
| 308 |
+
assert all(len(offset) == len(var_sizes[0]) for offset in offsets)
|
| 309 |
+
output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
|
| 310 |
+
kernel_group = KernelGroup()
|
| 311 |
+
kernel_group.args = self.args
|
| 312 |
+
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
| 313 |
+
bodies = []
|
| 314 |
+
var_sizes_list = []
|
| 315 |
+
for i, node in enumerate(nodes):
|
| 316 |
+
output_name = output_names[i]
|
| 317 |
+
node = node.data if isinstance(node, ir.ComputedBuffer) else node
|
| 318 |
+
assert isinstance(node, ir.Pointwise), node
|
| 319 |
+
|
| 320 |
+
def fn(*args):
|
| 321 |
+
assert len(args) == 2
|
| 322 |
+
assert len(args[0]) == len(var_sizes[0])
|
| 323 |
+
assert len(args[1]) == 0
|
| 324 |
+
new_args = [arg + offset for arg, offset in zip(args[0], offsets[i])] # type: ignore[arg-type]
|
| 325 |
+
if reindexers[i] is not None:
|
| 326 |
+
new_args = reindexers[i](new_args) # type: ignore[misc]
|
| 327 |
+
V.ops.store(
|
| 328 |
+
output_name,
|
| 329 |
+
output_index,
|
| 330 |
+
node.make_loader()(new_args).value,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
body = LoopBody(
|
| 334 |
+
fn,
|
| 335 |
+
(list(var_ranges.keys()), ()),
|
| 336 |
+
var_ranges,
|
| 337 |
+
list(var_ranges.keys()),
|
| 338 |
+
tuple(),
|
| 339 |
+
)
|
| 340 |
+
bodies.append(body)
|
| 341 |
+
var_sizes_list.append(var_sizes)
|
| 342 |
+
|
| 343 |
+
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
|
| 344 |
+
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
| 345 |
+
return kernel_group.loops_code.getvalue()
|
| 346 |
+
|
| 347 |
+
def store_output(
|
| 348 |
+
self,
|
| 349 |
+
dst: ir.Buffer,
|
| 350 |
+
src: ir.Buffer,
|
| 351 |
+
orig_src: Optional[ir.Buffer] = None,
|
| 352 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 353 |
+
offsets: Optional[list[Any]] = None,
|
| 354 |
+
reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
|
| 355 |
+
):
|
| 356 |
+
"""
|
| 357 |
+
Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match.
|
| 358 |
+
If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues
|
| 359 |
+
before stored to `dst`. The `epilogues_nodes` are all pointwise.
|
| 360 |
+
|
| 361 |
+
Notes:
|
| 362 |
+
1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute
|
| 363 |
+
and stores. In case `epilogue_nodes` are not provided, we do nothing.
|
| 364 |
+
2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since
|
| 365 |
+
they come form the original Inductor IR, they might need to be adjusted before working with
|
| 366 |
+
`src` and `dst` as outlined below:
|
| 367 |
+
a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on.
|
| 368 |
+
In this case, the `offsets` could be provided to adjust the indices passed to
|
| 369 |
+
`epilogue_nodes` during codegen and the data ranges are also configured according to
|
| 370 |
+
the sizes of `src` and `dst`.
|
| 371 |
+
b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
|
| 372 |
+
needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
|
| 373 |
+
c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
|
| 374 |
+
in `epilogue_nodes` with `src`.
|
| 375 |
+
"""
|
| 376 |
+
assert isinstance(dst, (ir.Buffer, ir.ReinterpretView))
|
| 377 |
+
assert dst.get_size() == src.get_size(), f"{dst=}, {src=}"
|
| 378 |
+
if offsets:
|
| 379 |
+
offsets = parse_expr_with_index_symbols(offsets)
|
| 380 |
+
if epilogue_nodes:
|
| 381 |
+
with LocalBufferContext(self.args) as scope:
|
| 382 |
+
assert orig_src is not None
|
| 383 |
+
if orig_src.get_name() != src.get_name():
|
| 384 |
+
scope.add_local_buffer(
|
| 385 |
+
src,
|
| 386 |
+
[
|
| 387 |
+
orig_src,
|
| 388 |
+
],
|
| 389 |
+
)
|
| 390 |
+
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
|
| 391 |
+
return self.store_pointwise_nodes(
|
| 392 |
+
dst,
|
| 393 |
+
epilogue_nodes, # type: ignore[arg-type]
|
| 394 |
+
offsets,
|
| 395 |
+
reindexers,
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
if dst.get_name() != src.get_name():
|
| 399 |
+
# src is local
|
| 400 |
+
copy = L.copy(dst, src).data.data
|
| 401 |
+
with LocalBufferContext(self.args) as scope:
|
| 402 |
+
scope.add_local_buffer(src)
|
| 403 |
+
return self.store_pointwise_nodes(dst, [copy])
|
| 404 |
+
else:
|
| 405 |
+
assert dst.layout == src.layout, f"{dst=}, {src=}"
|
| 406 |
+
return ""
|
| 407 |
+
|
| 408 |
+
def store_outputs(
|
| 409 |
+
self,
|
| 410 |
+
dst: tuple[ir.Buffer],
|
| 411 |
+
src: tuple[ir.IRNode],
|
| 412 |
+
orig_src: Optional[tuple[ir.IRNode]] = None,
|
| 413 |
+
epilogue_nodes: Optional[list[ir.IRNode]] = None,
|
| 414 |
+
offsets: Optional[list[Any]] = None,
|
| 415 |
+
reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
|
| 416 |
+
multi_output_buffers: Optional[tuple[ir.MultiOutput]] = None,
|
| 417 |
+
):
|
| 418 |
+
assert isinstance(dst, Iterable)
|
| 419 |
+
assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
|
| 420 |
+
if offsets:
|
| 421 |
+
offsets = parse_expr_with_index_symbols(offsets)
|
| 422 |
+
gemm_num = len(src)
|
| 423 |
+
final_offsets = []
|
| 424 |
+
output_names = []
|
| 425 |
+
if epilogue_nodes:
|
| 426 |
+
if not reindexers:
|
| 427 |
+
reindexers = [None] * len(epilogue_nodes)
|
| 428 |
+
with LocalBufferContext(self.args) as scope:
|
| 429 |
+
assert orig_src is not None
|
| 430 |
+
localize_epilogue_nodes = []
|
| 431 |
+
all_read_names = []
|
| 432 |
+
for epilogue in epilogue_nodes:
|
| 433 |
+
all_read_names.extend(list(epilogue.get_read_names()))
|
| 434 |
+
localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes))
|
| 435 |
+
final_offsets.extend([offsets] * len(localize_epilogue_nodes))
|
| 436 |
+
output_names.extend(
|
| 437 |
+
[node.get_name() for node in localize_epilogue_nodes]
|
| 438 |
+
)
|
| 439 |
+
for gemm_idx in range(gemm_num):
|
| 440 |
+
if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name():
|
| 441 |
+
if orig_src[gemm_idx].get_name() in all_read_names or (
|
| 442 |
+
multi_output_buffers
|
| 443 |
+
and multi_output_buffers[gemm_idx].get_name()
|
| 444 |
+
in all_read_names
|
| 445 |
+
):
|
| 446 |
+
# If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output
|
| 447 |
+
global_buffers = [orig_src[gemm_idx]]
|
| 448 |
+
if (
|
| 449 |
+
multi_output_buffers
|
| 450 |
+
and multi_output_buffers[gemm_idx].get_name()
|
| 451 |
+
in all_read_names
|
| 452 |
+
and orig_src[gemm_idx].get_name() not in all_read_names
|
| 453 |
+
):
|
| 454 |
+
# Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer
|
| 455 |
+
# if this MultiOutput has not been stored by in-template epilogue
|
| 456 |
+
# otherwise, use the cse store cache if it will be stored before used
|
| 457 |
+
global_buffers.append(multi_output_buffers[gemm_idx])
|
| 458 |
+
scope.add_local_buffer(
|
| 459 |
+
src[gemm_idx],
|
| 460 |
+
global_buffers,
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
scope.add_local_buffer(src[gemm_idx])
|
| 464 |
+
localize_epilogue_nodes.extend(
|
| 465 |
+
[L.copy(dst[gemm_idx], src[gemm_idx]).data.data]
|
| 466 |
+
)
|
| 467 |
+
reindexers.append(None)
|
| 468 |
+
output_names.append(dst[gemm_idx].get_name())
|
| 469 |
+
final_offsets.append(
|
| 470 |
+
[sympy.S.Zero] * len(dst[gemm_idx].get_size())
|
| 471 |
+
)
|
| 472 |
+
res = self.store_grouped_gemm_pointwise_nodes(
|
| 473 |
+
dst,
|
| 474 |
+
localize_epilogue_nodes,
|
| 475 |
+
final_offsets,
|
| 476 |
+
reindexers,
|
| 477 |
+
output_names=output_names,
|
| 478 |
+
)
|
| 479 |
+
for gemm_idx in range(gemm_num):
|
| 480 |
+
if (
|
| 481 |
+
multi_output_buffers
|
| 482 |
+
and multi_output_buffers[gemm_idx].get_name() in all_read_names
|
| 483 |
+
):
|
| 484 |
+
# If the MultiOutput is used in the Epilogue, let's remove it from args
|
| 485 |
+
multi_output_name = multi_output_buffers[gemm_idx].get_name()
|
| 486 |
+
if (
|
| 487 |
+
multi_output_name in self.args.output_buffers
|
| 488 |
+
and self.args.output_buffers[multi_output_name]
|
| 489 |
+
is not REMOVED
|
| 490 |
+
):
|
| 491 |
+
self.remove_buffer(multi_output_name)
|
| 492 |
+
return res
|
| 493 |
+
else:
|
| 494 |
+
if dst[0].get_name() != src[0].get_name():
|
| 495 |
+
copy_list = []
|
| 496 |
+
with LocalBufferContext(self.args) as scope:
|
| 497 |
+
for _src, _dst in zip(src, dst):
|
| 498 |
+
copy_list.extend([L.copy(_dst, _src).data.data])
|
| 499 |
+
scope.add_local_buffer(_src)
|
| 500 |
+
output_names.append(_dst.get_name())
|
| 501 |
+
final_offsets.append([sympy.S.Zero] * len(_dst.get_size()))
|
| 502 |
+
reindexers = [None] * len(copy_list)
|
| 503 |
+
return self.store_grouped_gemm_pointwise_nodes(
|
| 504 |
+
dst,
|
| 505 |
+
nodes=copy_list,
|
| 506 |
+
offsets=final_offsets,
|
| 507 |
+
reindexers=reindexers,
|
| 508 |
+
output_names=output_names,
|
| 509 |
+
)
|
| 510 |
+
else:
|
| 511 |
+
assert all(
|
| 512 |
+
_src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
|
| 513 |
+
)
|
| 514 |
+
assert all(
|
| 515 |
+
_src.get_layout() == _dst.get_layout()
|
| 516 |
+
for _src, _dst in zip(src, dst)
|
| 517 |
+
)
|
| 518 |
+
return ""
|
| 519 |
+
|
| 520 |
+
def check_bounds(self, expr, size, lower, upper):
|
| 521 |
+
# CppTemplateKernel does not need codegen related operations
|
| 522 |
+
return
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class CppTemplateCaller(ir.ChoiceCaller):
|
| 526 |
+
"""
|
| 527 |
+
CppTemplateCaller
|
| 528 |
+
|
| 529 |
+
This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller.
|
| 530 |
+
Attributes:
|
| 531 |
+
name (str): The name of the caller.
|
| 532 |
+
category (str): The category of the caller.
|
| 533 |
+
bmreq (CppBenchmarkRequest): The benchmark request for the caller.
|
| 534 |
+
template_buffer (ir.CppTemplateBuffer): The template buffer for the caller.
|
| 535 |
+
"""
|
| 536 |
+
|
| 537 |
+
def __init__(
|
| 538 |
+
self,
|
| 539 |
+
name: str,
|
| 540 |
+
category: str,
|
| 541 |
+
input_nodes: list[ir.Buffer],
|
| 542 |
+
layout: ir.Layout,
|
| 543 |
+
make_kernel_render: Callable[
|
| 544 |
+
[
|
| 545 |
+
ir.CppTemplateBuffer,
|
| 546 |
+
bool,
|
| 547 |
+
Optional[list[ir.IRNode]],
|
| 548 |
+
],
|
| 549 |
+
str,
|
| 550 |
+
],
|
| 551 |
+
bmreq: CppBenchmarkRequest,
|
| 552 |
+
template: "CppTemplate", # type: ignore[name-defined] # noqa: F821
|
| 553 |
+
info_kwargs: Optional[
|
| 554 |
+
dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]
|
| 555 |
+
] = None,
|
| 556 |
+
):
|
| 557 |
+
super().__init__(name, input_nodes, layout, description="")
|
| 558 |
+
self.category = category
|
| 559 |
+
self.make_kernel_render = make_kernel_render
|
| 560 |
+
self.bmreq = bmreq
|
| 561 |
+
self.template = template
|
| 562 |
+
self.info_kwargs = info_kwargs
|
| 563 |
+
|
| 564 |
+
def precompile(self) -> None:
|
| 565 |
+
assert self.bmreq is not None
|
| 566 |
+
self.bmreq.precompile()
|
| 567 |
+
|
| 568 |
+
def benchmark(self, *args, out) -> float:
|
| 569 |
+
assert self.bmreq is not None
|
| 570 |
+
if config.profile_bandwidth_with_do_bench_using_profiling:
|
| 571 |
+
algo = self.bmreq.make_run_fn(*args, out=out)
|
| 572 |
+
return do_bench_using_profiling(algo)
|
| 573 |
+
return self.bmreq.benchmark(*args, out=out)
|
| 574 |
+
|
| 575 |
+
def hash_key(self) -> str:
|
| 576 |
+
return "-".join(
|
| 577 |
+
[
|
| 578 |
+
self.category,
|
| 579 |
+
self.bmreq.hash_key,
|
| 580 |
+
]
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
def info_dict(
|
| 584 |
+
self,
|
| 585 |
+
) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]:
|
| 586 |
+
return {"backend": "CPP", "op_type": "unknown"}
|
| 587 |
+
|
| 588 |
+
def output_node(self) -> ir.TensorBox:
|
| 589 |
+
return ir.TensorBox.create(
|
| 590 |
+
ir.CppTemplateBuffer(
|
| 591 |
+
layout=self.layout,
|
| 592 |
+
inputs=self.input_nodes,
|
| 593 |
+
make_kernel_render=self.make_kernel_render,
|
| 594 |
+
template=self.template,
|
| 595 |
+
choice=self,
|
| 596 |
+
)
|
| 597 |
+
)
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import math
|
| 6 |
+
import sys
|
| 7 |
+
from collections import namedtuple
|
| 8 |
+
from collections.abc import Sequence
|
| 9 |
+
from typing import Any, Callable, Optional
|
| 10 |
+
from unittest.mock import patch
|
| 11 |
+
|
| 12 |
+
import sympy
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch._prims_common import is_integer_dtype
|
| 16 |
+
from torch.utils._ordered_set import OrderedSet
|
| 17 |
+
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
|
| 18 |
+
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
| 19 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 20 |
+
|
| 21 |
+
from .. import ir
|
| 22 |
+
from ..dependencies import Dep
|
| 23 |
+
from ..loop_body import LoopBody
|
| 24 |
+
from ..scheduler import BaseSchedulerNode, SchedulerBuffer
|
| 25 |
+
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
|
| 26 |
+
from ..virtualized import ops, OpsValue, V
|
| 27 |
+
from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
DTYPE_TO_CPP = {
|
| 31 |
+
torch.float32: "float",
|
| 32 |
+
torch.float64: "double",
|
| 33 |
+
torch.float16: "at::Half",
|
| 34 |
+
torch.int64: "int64_t",
|
| 35 |
+
torch.int32: "int32_t",
|
| 36 |
+
torch.int16: "int16_t",
|
| 37 |
+
torch.int8: "int8_t",
|
| 38 |
+
torch.uint64: "uint64_t",
|
| 39 |
+
torch.uint32: "uint32_t",
|
| 40 |
+
torch.uint16: "uint16_t",
|
| 41 |
+
torch.uint8: "uint8_t",
|
| 42 |
+
torch.bool: "bool",
|
| 43 |
+
torch.bfloat16: "at::BFloat16",
|
| 44 |
+
torch.complex32: "at::complex<at::Half>",
|
| 45 |
+
torch.complex64: "at::complex<float>",
|
| 46 |
+
torch.complex128: "at::complex<double>",
|
| 47 |
+
torch.float8_e4m3fn: "at::Float8_e4m3fn",
|
| 48 |
+
torch.float8_e5m2: "at::Float8_e5m2",
|
| 49 |
+
torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz",
|
| 50 |
+
torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz",
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
DTYPE_TO_ATEN = {
|
| 54 |
+
torch.float32: "at::kFloat",
|
| 55 |
+
torch.float64: "at::kDouble",
|
| 56 |
+
torch.float16: "at::kHalf",
|
| 57 |
+
torch.int64: "at::kLong",
|
| 58 |
+
torch.int32: "at::kInt",
|
| 59 |
+
torch.int16: "at::kShort",
|
| 60 |
+
torch.int8: "at::kChar",
|
| 61 |
+
torch.uint64: "at::kUInt64",
|
| 62 |
+
torch.uint32: "at::kUInt32",
|
| 63 |
+
torch.uint16: "at::kUInt16",
|
| 64 |
+
torch.uint8: "at::kByte",
|
| 65 |
+
torch.uint32: "at::kUInt32",
|
| 66 |
+
torch.uint64: "at::kUInt64",
|
| 67 |
+
torch.bool: "at::kBool",
|
| 68 |
+
torch.bfloat16: "at::kBFloat16",
|
| 69 |
+
torch.complex32: "at::kComplexHalf",
|
| 70 |
+
torch.complex64: "at::kComplexFloat",
|
| 71 |
+
torch.complex128: "at::kComplexDouble",
|
| 72 |
+
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
|
| 73 |
+
torch.float8_e5m2: "at::kFloat8_e5m2",
|
| 74 |
+
torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
|
| 75 |
+
torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
DEVICE_TO_ATEN = {
|
| 79 |
+
"meta": "at::kMeta",
|
| 80 |
+
"cpu": "at::kCPU",
|
| 81 |
+
"cuda": "at::kCUDA",
|
| 82 |
+
"xpu": "at::kXPU",
|
| 83 |
+
"mps": "at::kMPS",
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
LAYOUT_TO_ATEN = {
|
| 87 |
+
torch.strided: "at::kStrided",
|
| 88 |
+
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# matches c10/core/DeviceType.h
|
| 92 |
+
DEVICE_TO_INT = {"cpu": 0, "cuda": 1}
|
| 93 |
+
|
| 94 |
+
_IS_WINDOWS = sys.platform == "win32"
|
| 95 |
+
|
| 96 |
+
INDEX_TYPE = "int64_t"
|
| 97 |
+
|
| 98 |
+
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_promote_dtype(args):
|
| 102 |
+
return (
|
| 103 |
+
functools.reduce(
|
| 104 |
+
torch.promote_types, # type: ignore[arg-type]
|
| 105 |
+
[n.dtype for n in args if isinstance(n, CppCSEVariable)],
|
| 106 |
+
)
|
| 107 |
+
if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable))
|
| 108 |
+
else None # not enough info to calculate the promote dtype
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def promote_args(new_args):
|
| 113 |
+
def promote_arg(arg, promote_type):
|
| 114 |
+
if (
|
| 115 |
+
isinstance(arg, CppCSEVariable)
|
| 116 |
+
and arg.dtype
|
| 117 |
+
and promote_type
|
| 118 |
+
and arg.dtype != promote_type
|
| 119 |
+
):
|
| 120 |
+
arg = ops.to_dtype(arg, promote_type)
|
| 121 |
+
arg = arg.value if isinstance(arg, OpsValue) else arg
|
| 122 |
+
arg.dtype = promote_type
|
| 123 |
+
return arg
|
| 124 |
+
|
| 125 |
+
promote_type = get_promote_dtype(new_args)
|
| 126 |
+
promote_fn = functools.partial(
|
| 127 |
+
promote_arg,
|
| 128 |
+
promote_type=promote_type,
|
| 129 |
+
)
|
| 130 |
+
if (
|
| 131 |
+
all(
|
| 132 |
+
new_arg.dtype is not None
|
| 133 |
+
for new_arg in new_args
|
| 134 |
+
if isinstance(new_arg, CppCSEVariable)
|
| 135 |
+
)
|
| 136 |
+
and promote_type
|
| 137 |
+
):
|
| 138 |
+
new_args = list(map(promote_fn, new_args))
|
| 139 |
+
return new_args
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class CppCSEVariable(CSEVariable):
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
name,
|
| 146 |
+
bounds: ValueRanges[Any],
|
| 147 |
+
dtype: Optional[torch.dtype] = None,
|
| 148 |
+
) -> None:
|
| 149 |
+
super().__init__(name, bounds, dtype)
|
| 150 |
+
self.is_vec = False
|
| 151 |
+
self.dependent_itervars = OrderedSet[sympy.Symbol]()
|
| 152 |
+
|
| 153 |
+
def __repr__(self) -> str:
|
| 154 |
+
return (
|
| 155 |
+
f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
|
| 156 |
+
f"dependent_itervars: {self.dependent_itervars})"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def update_on_args(self, name, args, kwargs):
|
| 160 |
+
if name == "load":
|
| 161 |
+
# args[2] is index
|
| 162 |
+
self._set_dependent_itervars(args[2])
|
| 163 |
+
else:
|
| 164 |
+
# propagate relevant itervars and is_vec from args
|
| 165 |
+
self.dependent_itervars.update(
|
| 166 |
+
*[
|
| 167 |
+
arg.dependent_itervars
|
| 168 |
+
for arg in args
|
| 169 |
+
if isinstance(arg, CppCSEVariable)
|
| 170 |
+
]
|
| 171 |
+
)
|
| 172 |
+
if name == "index_expr":
|
| 173 |
+
self._set_dependent_itervars(args[0])
|
| 174 |
+
if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
|
| 175 |
+
self.is_vec = True
|
| 176 |
+
|
| 177 |
+
def _set_dependent_itervars(self, index: sympy.Expr):
|
| 178 |
+
"""
|
| 179 |
+
Set the relevant itervars for this variable based on the `index` expression.
|
| 180 |
+
This includes the itervars directly used in the `index` as well as relevant itervars
|
| 181 |
+
of other cse variables used in the `index`.
|
| 182 |
+
"""
|
| 183 |
+
for s in index.free_symbols:
|
| 184 |
+
if s in V.kernel.itervars:
|
| 185 |
+
self.dependent_itervars.add(s) # type: ignore[arg-type]
|
| 186 |
+
elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined]
|
| 187 |
+
self.dependent_itervars.update(
|
| 188 |
+
V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined]
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def depends_on(self, itervar: sympy.Symbol):
|
| 192 |
+
return itervar in self.dependent_itervars
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class CppPrinter(_CppPrinter):
|
| 196 |
+
def doprint(self, expr, *, simplify: bool = True, p=True):
|
| 197 |
+
# TODO: why are people passing strings to the printer here :think:
|
| 198 |
+
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
| 199 |
+
expr = V.graph.sizevars.simplify(expr)
|
| 200 |
+
return super().doprint(expr)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# A function to print, useful for printing sympy symbols.
|
| 204 |
+
cexpr = CppPrinter().doprint
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def cexpr_index(index):
|
| 208 |
+
return f"static_cast<{INDEX_TYPE}>({cexpr(index)})"
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def value_to_cpp(value, cpp_type):
|
| 212 |
+
if value == float("-inf"):
|
| 213 |
+
return f"-std::numeric_limits<{cpp_type}>::infinity()"
|
| 214 |
+
elif value == float("inf"):
|
| 215 |
+
return f"std::numeric_limits<{cpp_type}>::infinity()"
|
| 216 |
+
elif isinstance(value, bool):
|
| 217 |
+
return f"static_cast<{cpp_type}>({str(value).lower()})"
|
| 218 |
+
elif math.isnan(value):
|
| 219 |
+
return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
|
| 220 |
+
else:
|
| 221 |
+
return f"static_cast<{cpp_type}>({repr(value)})"
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def rewrite_index_for_function(
|
| 225 |
+
localize_buffer_handler: "LocalizeBufferHandler",
|
| 226 |
+
index: sympy.Expr,
|
| 227 |
+
global_buf_name: str,
|
| 228 |
+
):
|
| 229 |
+
# Local buffer at the inner dimensions
|
| 230 |
+
snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
|
| 231 |
+
assert snode is not None
|
| 232 |
+
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
|
| 233 |
+
scheduler_nodes = snode.get_nodes()
|
| 234 |
+
_, (group, reduction_group) = max(
|
| 235 |
+
scheduler_nodes, key=lambda x: int(x.is_reduction())
|
| 236 |
+
).group
|
| 237 |
+
call_ranges = tuple(group) + tuple(reduction_group)
|
| 238 |
+
indices_to_keep = [
|
| 239 |
+
f"x{len(call_ranges) - (idx + 1)}"
|
| 240 |
+
for idx in range(len(local_buf.get_layout().size))
|
| 241 |
+
]
|
| 242 |
+
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
|
| 243 |
+
replacements = {}
|
| 244 |
+
for x in sorted_symbols:
|
| 245 |
+
if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
|
| 246 |
+
# Only keep index used by local buffer
|
| 247 |
+
replacements[x] = sympy.core.numbers.Zero()
|
| 248 |
+
index = sympy_subs(index, replacements) # type: ignore[arg-type]
|
| 249 |
+
return index
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def rewrite_index_for_nodes(
|
| 253 |
+
localize_buffer_handler: "LocalizeBufferHandler",
|
| 254 |
+
index: sympy.Expr,
|
| 255 |
+
global_buf_name: str,
|
| 256 |
+
):
|
| 257 |
+
used_vars = OrderedSet(
|
| 258 |
+
s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
|
| 259 |
+
)
|
| 260 |
+
index_vars = []
|
| 261 |
+
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
|
| 262 |
+
for i in range(len(local_buf.get_size())):
|
| 263 |
+
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
| 264 |
+
index_vars.append(var if var in used_vars else 0)
|
| 265 |
+
index = local_buf.get_layout().make_indexer()(index_vars)
|
| 266 |
+
return index
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
| 270 |
+
def __init__(
|
| 271 |
+
self,
|
| 272 |
+
inner,
|
| 273 |
+
global_to_local: dict[str, ir.Buffer],
|
| 274 |
+
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
|
| 275 |
+
) -> None:
|
| 276 |
+
super().__init__(inner)
|
| 277 |
+
self.global_to_local = global_to_local
|
| 278 |
+
self.rewrite_index = rewrite_index
|
| 279 |
+
|
| 280 |
+
def localize(self, name: str, index: sympy.Expr):
|
| 281 |
+
if self.global_to_local and name in self.global_to_local:
|
| 282 |
+
assert self.rewrite_index is not None
|
| 283 |
+
index = self.rewrite_index(self, index, name)
|
| 284 |
+
name = self.global_to_local[name].get_name()
|
| 285 |
+
return name, index
|
| 286 |
+
|
| 287 |
+
def load(self, name: str, index: sympy.Expr):
|
| 288 |
+
return self._inner.load(*self.localize(name, index))
|
| 289 |
+
|
| 290 |
+
def store(self, name, index, value, mode=None):
|
| 291 |
+
local_buffer_name, local_buffer_index = self.localize(name, index)
|
| 292 |
+
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
|
| 293 |
+
if (
|
| 294 |
+
self.global_to_local
|
| 295 |
+
and name in self.global_to_local
|
| 296 |
+
and isinstance(V.kernel, Kernel)
|
| 297 |
+
):
|
| 298 |
+
# Remove name of local buffer from Kernel.store_buffer_names
|
| 299 |
+
# local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
|
| 300 |
+
V.kernel.store_buffer_names.discard(local_buffer_name)
|
| 301 |
+
return res
|
| 302 |
+
|
| 303 |
+
def store_reduction(self, name, index, value):
|
| 304 |
+
return self._inner.store_reduction(*self.localize(name, index), value)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class LocalBufferContext:
    """
    Context that helps generate code involving Inductor IR with function-local
    buffers.  These buffers are constructed during the codegen process and are
    used to store intermediate results such as local accumulators.  They are
    deliberately not registered in ``V.graph`` (they are not global) and are not
    added as function arguments either; instead, the codegen hooks are patched
    for the duration of this scope so the buffers can be used without being
    exposed to the outside world.
    """

    def __init__(self, kernel_args: KernelArgs) -> None:
        self.kernel_args = kernel_args
        self.exit_stack = contextlib.ExitStack()
        # local buffer name -> local buffer
        self.local_buffers: dict[str, ir.Buffer] = {}
        # global buffer name -> global buffer
        self.global_buffers: dict[str, ir.Buffer] = {}
        # global buffer name -> the local buffer replacing it
        self.global_to_local: dict[str, ir.Buffer] = {}
        # global buffers removed from V.graph by this context; they may need to
        # be restored later (see __init__ of callers / issue 144186)
        self.removed_buffers: OrderedSet[str] = OrderedSet()

    def __enter__(self):
        self.exit_stack.__enter__()

        original_get_dtype = V.graph.get_dtype

        def _get_dtype(name):
            # Local buffers are not in V.graph, so answer for them directly.
            if name in self.local_buffers:
                return self.local_buffers[name].get_dtype()
            return original_get_dtype(name)

        self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", _get_dtype))

        original_input = self.kernel_args.input

        def _input(name):
            # A local buffer is referenced by its own name, not a kernel arg.
            return name if name in self.local_buffers else original_input(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "input", _input))

        original_output = self.kernel_args.output

        def _output(name):
            return name if name in self.local_buffers else original_output(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "output", _output))

        # Publish the current LocalBufferContext through V.
        self.exit_stack.enter_context(V.set_local_buffer_context(self))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(
        self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None
    ):
        """Register ``local_buffer`` and optionally map ``global_buffers`` onto it."""
        assert local_buffer.get_name() not in self.local_buffers
        self.local_buffers[local_buffer.get_name()] = local_buffer
        for global_buffer in global_buffers or []:
            global_buffer_name = global_buffer.get_name()
            assert (
                global_buffer_name not in self.global_buffers
                and global_buffer_name not in self.global_to_local
            )
            self.global_buffers[global_buffer_name] = global_buffer
            self.global_to_local[global_buffer_name] = local_buffer
            if global_buffer_name not in V.graph.removed_buffers:
                # Record the global buffers removed by this LocalBufferContext,
                # since they may need to be restored.  Refer to:
                # https://github.com/pytorch/pytorch/issues/144186
                self.removed_buffers.add(global_buffer_name)
                V.graph.removed_buffers.add(global_buffer_name)

    def localize_function(
        self,
        fn: Callable[..., Any],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_function,
    ):
        """Wrap ``fn`` so that its loads/stores are redirected to local buffers."""

        def inner(*args, **kwargs):
            handler = LocalizeBufferHandler(
                V.get_ops_handler(),
                global_to_local=self.global_to_local,
                rewrite_index=rewrite_index,
            )
            with V.set_ops_handler(handler):
                return fn(*args, **kwargs)

        return inner

    def localize_nodes(
        self,
        nodes: list[ir.IRNode],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_nodes,
    ) -> list[ir.IRNode]:
        """
        Given ``local_buf`` and ``global_buf`` registered in the current
        ``LocalBufferContext`` through ``add_local_buffer``, localize
        ``global_buf`` to ``local_buf`` for the given ``nodes`` and return a new
        list of IR nodes that work on ``local_buf`` instead of ``global_buf``,
        i.e., all loads and stores are redirected to ``local_buf``.  This helps
        fused loops work on smaller local buffers for better data locality.

        The data access of ``local_buf`` is assumed to be contiguous, in the
        same order as ``global_buf``.
        """
        assert len(nodes) > 0

        def wrap_inner_fn_for_node(node: ir.IRNode):
            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(loops, ir.Loops)
            localized_fn = self.localize_function(loops.inner_fn, rewrite_index)
            new_loops = dataclasses.replace(loops, inner_fn=localized_fn)
            if isinstance(node, ir.ComputedBuffer):
                return ir.ComputedBuffer(
                    name=node.get_name(), layout=node.get_layout(), data=new_loops
                )
            return new_loops  # type: ignore[return-value]

        return [wrap_inner_fn_for_node(node) for node in nodes]
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def unify_mask_base_type(
    buffer: IndentedBuffer,
    vars: tuple[CSEVariable, ...],
    dtype=torch.float,
):
    """
    Cast each given CSE variable to the mask base ``dtype``.

    Returns a lazy generator of the casted CSE variables, in the same order
    as ``vars``.
    """
    return (
        V.kernel.cse.generate(
            buffer,
            f"{V.kernel._get_mask_cast(var, dtype)}",
        )
        for var in vars
    )
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def may_unify_binary_op_mask_type(a, b):
    """
    Given two CSE variables, unify both to the same mask dtype (int32) when
    they are boolean; otherwise return them unchanged.
    """
    if a.dtype != torch.bool:
        return a, b
    assert b.dtype == torch.bool
    # Both operands are masks: widen them to a common integer mask type.
    return unify_mask_base_type(V.kernel.compute, (a, b), torch.int32)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
    """
    Emit a C++ immediately-invoked lambda that materializes ``tiling_factor``
    random values (produced by ``rand_function``) into a vectorized register.

    ``offset`` is an integer-typed vector variable whose lanes seed the
    per-element random generation.  Returns ``code`` for chaining.
    """
    assert is_integer_dtype(offset.dtype)
    offset_cpp = DTYPE_TO_CPP[offset.dtype]
    dst_cpp = DTYPE_TO_CPP[dst_dtype]
    tiling = V.kernel.tiling_factor
    code.writeline("[&]()")
    with code.indent():
        # Spill the vector offsets to a scalar array, fill results one lane
        # at a time, then reload as a vectorized value.
        code.writeline(f"{offset_cpp} offset[{tiling}];")
        code.writeline(f"{dst_cpp} result[{tiling}];")
        code.writeline(f"{offset}.store(offset);")
        code.writeline(
            f"for( {offset_cpp} offset_idx = 0; offset_idx < {tiling}; offset_idx++ )"
        )
        with code.indent():
            code.writeline(rand_function)
        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
        if num_vectors == 1:
            code.writeline(f"return at::vec::Vectorized<{dst_cpp}>::loadu(result);")
        else:
            code.writeline(
                f"return at::vec::VectorizedN<{dst_cpp}, {num_vectors}>::loadu(result);"
            )
    code.writeline("()")
    return code
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def get_gemm_template_output_and_compute_dtype(input_dtype):
    """
    Return ``(output_dtype, compute_dtype)`` for the GEMM template.

    Integer (u)int8 inputs accumulate in int32; everything else uses float32.
    """
    if input_dtype in (torch.uint8, torch.int8):
        return (torch.int32, torch.int32)
    return (torch.float32, torch.float32)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def create_epilogue_with_attr(input_buffer, attr, **kwargs):
    """
    Build a Pointwise epilogue applying ``attr`` on top of ``input_buffer``.

    Supported attrs: relu, gelu (algorithm "none" or "tanh"), swish, sigmoid,
    tanh, hardswish, hardsigmoid, leaky_relu, hardtanh, add/sub/mul (binary
    with ``other``), and bias_add (``other``, ``beta``, ``dtype``).  Raises
    ValueError on anything else.
    """
    input_loader = input_buffer.make_loader()
    dtype = input_buffer.get_dtype()
    if attr == "relu":

        def inner_fn(index):
            input = input_loader(index)
            zero = ops.constant(0, dtype)
            return ops.maximum(input, zero)

    elif attr == "gelu":
        assert "algorithm" in kwargs
        if kwargs["algorithm"] == "none":
            # Exact erf-based GELU, computed in float and cast back.

            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                const = ops.constant(0.7071067811865476, torch.float)
                result = input * half * (ops.erf(input * const) + one)
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

        else:
            assert kwargs["algorithm"] == "tanh"
            # Tanh approximation of GELU.

            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                const1 = ops.constant(0.7978845608028654, torch.float)
                const2 = ops.constant(0.044715, torch.float)
                result = (
                    half
                    * input
                    * (
                        one
                        + ops.tanh(const1 * (input + const2 * input * input * input))
                    )
                )
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

    elif attr == "swish":

        def inner_fn(index):
            input = input_loader(index)
            return input * ops.sigmoid(input)

    elif attr == "sigmoid":

        def inner_fn(index):
            return ops.sigmoid(input_loader(index))

    elif attr == "tanh":

        def inner_fn(index):
            return ops.tanh(input_loader(index))

    elif attr in ("hardswish", "hardsigmoid"):

        def hardsigmoid_float(input):
            # clamp(input + 3, 0, 6) / 6, all in float
            zero = ops.constant(0, torch.float)
            six = ops.constant(6, torch.float)
            three = ops.constant(3, torch.float)
            one_over_six = ops.constant(0.16666666666666666, torch.float)
            clamped_low = ops.maximum(input + three, zero)
            clamped = ops.minimum(clamped_low, six)
            return clamped * one_over_six

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = hardsigmoid_float(input)
            if attr == "hardswish":
                result = input * result
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "leaky_relu":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 1
        negative_slope = kwargs["scalars"][0]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            zero = ops.constant(0, torch.float)
            result = ops.where(
                input > zero, input, input * ops.constant(negative_slope, torch.float)
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "hardtanh":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 2
        min_value = kwargs["scalars"][0]
        max_value = kwargs["scalars"][1]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = ops.minimum(
                ops.maximum(input, ops.constant(min_value, torch.float)),
                ops.constant(max_value, torch.float),
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr in ["add", "sub", "mul"]:
        assert "other" in kwargs
        other = kwargs["other"]
        num_input_dims = len(input_buffer.get_size())
        num_other_dims = len(other.get_size())
        # ``other`` may have fewer (trailing-aligned) dims than the input.
        dims_diff = num_input_dims - num_other_dims
        other_loader = other.make_loader()

        def inner_fn(index):
            op = getattr(ops, attr)
            if dims_diff != 0:
                return op(input_loader(index), other_loader(index[dims_diff:]))
            else:
                return op(input_loader(index), other_loader(index))

    elif attr == "bias_add":
        assert "other" in kwargs
        assert "beta" in kwargs
        assert "dtype" in kwargs
        beta = kwargs["beta"]
        other = kwargs["other"]
        dtype = kwargs["dtype"]
        bias_loader = other.make_loader()

        def inner_fn(index):
            bias = bias_loader(index)
            input = input_loader(index)
            if beta != 1:
                result = ops.constant(beta, torch.float) * bias + input
            else:
                result = bias + input
            return result

    else:
        raise ValueError(f"Unsupported epilogue attribute: {attr}")
    return ir.Pointwise(
        device=input_buffer.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input_buffer.get_size(),
    )
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def _get_loop_body(fn_list):
    """
    Extract the ``LoopBody`` objects backing each entry of ``fn_list``.

    Entries may already be ``LoopBody`` instances, wrappers produced by
    ``localize_function`` (carrying ``original_fn``), or ``functools.partial``
    objects whose first positional argument holds a ``_body``.
    """
    if all(isinstance(fn, LoopBody) for fn in fn_list):
        loop_bodies = fn_list
    elif hasattr(fn_list[0], "original_fn"):
        # Local-buffer case: fns were wrapped via localize_function.
        assert all(hasattr(fn, "original_fn") for fn in fn_list)
        assert all(
            isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
        )
        loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
    else:
        assert all(isinstance(fn, functools.partial) for fn in fn_list)
        assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
        loop_bodies = [fn.args[0]._body for fn in fn_list]
    assert loop_bodies is not None
    return loop_bodies
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _get_dtype_from_loopbodies(loop_bodies):
    """
    Collect the set of dtypes attached (via OptimizationContext metadata) to
    every ``call_method`` node across all FX graphs of the given loop bodies.
    """
    dtypes = OrderedSet[torch.dtype]()
    for loop_body in loop_bodies:
        graphs = [loop_body.root_block.graph] + [
            body.graph for body in loop_body.subblocks.values()
        ]
        for graph in graphs:
            for node in graph.nodes:
                if node.op == "call_method":
                    dtypes.add(node.meta[OptimizationContext.key].dtype)
    return dtypes
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def template_fusion_with_epilogues_supported(
    template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode]
) -> tuple[bool, bool]:
    """
    Decide whether ``template`` can be fused with ``epilogues``.

    Returns ``(supported, same_index)`` where ``supported`` means each epilogue
    reads the template buffer(s) with a single index that matches all of its
    own writes, and ``same_index`` reports whether that index equality held.
    """

    def _get_indexes_of_template_buf_read(
        epilogue_node: ir.Operation, template_buf_names: list[str]
    ) -> list[sympy.Expr]:
        # All read indexes of the epilogue that target a template output.
        return [
            read.index
            for read in epilogue_node.get_reads()
            if read.name in template_buf_names
        ]

    def _check_supported_and_same_indexes(
        index_of_template_buf_read: Sequence[sympy.Expr],
        epilogue_writes: OrderedSet[Dep],
    ) -> tuple[bool, bool]:
        num_indexes = len(OrderedSet(index_of_template_buf_read))
        if num_indexes > 1:
            # Multiple distinct read indexes are not supported.
            return False, False
        if num_indexes == 0:
            # No reads of the template buffer: trivially supported.
            return True, True
        # Exactly one read index: supported iff every write uses it too.
        # TODO: support fusion when the template-buffer read and the epilogue
        # output write use different indexes, and flip supported to True.
        read_index = index_of_template_buf_read[0]
        same_index = all(write.index == read_index for write in epilogue_writes)
        return same_index, same_index

    def _template_fusion_supported(
        template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation]
    ) -> tuple[bool, bool]:
        template_buf_names = [x.get_name() for x in template_outputs]
        indexes_of_template_buf_reads = [
            _get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
            for epilogue_node in epilogue_nodes
        ]
        epilogue_nodes_writes = [
            epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
        ]
        results = [
            _check_supported_and_same_indexes(reads, writes)
            for reads, writes in zip(
                indexes_of_template_buf_reads, epilogue_nodes_writes
            )
        ]
        supported, same_indexes = zip(*results)
        return all(supported), all(same_indexes)

    assert template.is_template()
    template_outputs = template.get_outputs()
    epilogue_nodes = [
        n.node
        for epilogue in epilogues
        for n in epilogue.get_nodes()
        if n.node is not None
    ]
    return _template_fusion_supported(template_outputs, epilogue_nodes)
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
ADDED
|
@@ -0,0 +1,878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections.abc import Sequence
|
| 3 |
+
from typing import Any, Callable, Optional, Union
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
| 9 |
+
import torch._ops
|
| 10 |
+
|
| 11 |
+
from .. import config, ir
|
| 12 |
+
from ..utils import sympy_product
|
| 13 |
+
from ..virtualized import V
|
| 14 |
+
from .cpp_utils import DTYPE_TO_CPP
|
| 15 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 16 |
+
from .wrapper import (
|
| 17 |
+
BufferLike,
|
| 18 |
+
EnterSubgraphLine,
|
| 19 |
+
ExitSubgraphLine,
|
| 20 |
+
MemoryPlanningLine,
|
| 21 |
+
MemoryPlanningState,
|
| 22 |
+
PythonWrapperCodegen,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
BufferName = str
|
| 27 |
+
|
| 28 |
+
# Default thread stack sizes vary by platform:
|
| 29 |
+
# - Linux: 8 MB
|
| 30 |
+
# - macOS: 512 KB
|
| 31 |
+
# - Windows: 1 MB
|
| 32 |
+
# Just pick something comfortably smaller than the smallest for now.
|
| 33 |
+
MAX_STACK_ALLOCATION_SIZE = 1024 * 100
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CppWrapperCpuArrayRef(CppWrapperCpu):
|
| 37 |
+
"""
|
| 38 |
+
Generates cpp wrapper for running on CPU and calls cpp kernels
|
| 39 |
+
|
| 40 |
+
This class is forked from CppWrapperCpu, with a difference that tensors may be
|
| 41 |
+
represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self):
    """Initialize the ArrayRef CPU wrapper; only valid for CPU graphs."""
    super().__init__()
    assert self.device == "cpu", "ArrayRefTensor only supported on CPU!"
    # Stack allocation may later be disabled (e.g. by extern kernel calls).
    self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation
    # Buffers that were placed on the stack instead of the heap.
    self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
|
| 49 |
+
|
| 50 |
+
@staticmethod
def create(
    is_subgraph: bool,
    subgraph_name: Optional[str],
    parent_wrapper: Optional[PythonWrapperCodegen],
    partition_signatures: Optional[ir.GraphPartitionSignature] = None,
):
    """Factory hook; subgraph arguments are currently ignored.

    TODO: support subgraph codegen by lifting functions.  See the comment
    at CppWrapperCpu's ``codegen_subgraph``.
    """
    return CppWrapperCpuArrayRef()
|
| 60 |
+
|
| 61 |
+
@staticmethod
def get_input_cpp_type(input):
    """Map a graph input to its C++ parameter type for the minimal arrayref interface."""
    assert config.aot_inductor.use_minimal_arrayref_interface

    if isinstance(input, sympy.Expr):
        from ..graph import may_get_constant_buffer_dtype

        dtype = may_get_constant_buffer_dtype(input)
        assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
        return DTYPE_TO_CPP[dtype]
    # Tensor inputs are passed as ArrayRefTensor of their element type.
    return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
|
| 72 |
+
|
| 73 |
+
@staticmethod
def get_device_include_path(device: str) -> str:
    """Return the C++ include line for the ArrayRef runtime (CPU only)."""
    assert device == "cpu", "ArrayRef only supported on CPU!"
    # AOT mode uses the AOTI runtime header; JIT cpp_wrapper uses its own.
    if V.graph.aot_mode:
        return "#include <torch/csrc/inductor/aoti_include/array_ref.h>"
    return "#include <torch/csrc/inductor/cpp_wrapper/array_ref.h>"
|
| 79 |
+
|
| 80 |
+
def codegen_input_numel_asserts(self):
    """Emit runtime numel checks for each (non-symbolic) graph input."""
    for name, buf in V.graph.graph_inputs.items():
        if isinstance(buf, sympy.Expr):
            # Scalar/symbolic inputs have no numel to assert.
            continue

        # Comparing strides for a 0-size tensor is tricky; ignore them for now.
        if sympy_product(buf.get_size()) == 0:
            continue
        self.prefix.writeline(f"assert_numel({name}, {buf.get_numel()});")
|
| 90 |
+
|
| 91 |
+
def generate_extern_kernel_alloc(self, *args, **kwargs):
    """Extern kernels allocate real tensors, so stack allocation must be off."""
    self.allow_stack_allocation = False
    super().generate_extern_kernel_alloc(*args, **kwargs)
|
| 95 |
+
|
| 96 |
+
def generate_extern_kernel_out(self, *args, **kwargs):
    """Extern kernels require heap tensors; disable stack allocation first."""
    self.allow_stack_allocation = False
    super().generate_extern_kernel_out(*args, **kwargs)
|
| 100 |
+
|
| 101 |
+
def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
    """Fallback kernels behave like extern kernels; disable stack allocation."""
    self.allow_stack_allocation = False
    super().generate_fallback_kernel(node)
|
| 105 |
+
|
| 106 |
+
def _generate_kernel_call_helper(
    self,
    kernel_name: str,
    call_args,
    *,
    device=None,
    triton=True,
    arg_types=None,
    raw_keys=None,
    raw_args=None,
    triton_meta=None,
    graph_name="",
    original_fxnode_name=None,
):
    """
    Generate the C++ call for a cpp kernel.

    ``triton`` must be False here: this wrapper only emits cpp kernel calls
    and has no GPU/Triton path.  Pointer-typed arguments are unwrapped from
    their ArrayRef/tensor form via ``get_data_ptr_wrapper`` and cast to the
    declared C type; scalars pass through unchanged.
    """
    assert not triton, (
        "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
    )
    assert arg_types is not None and len(call_args) == len(arg_types), (
        "Mismatch call_args and arg_types in generate_kernel_call"
    )
    new_args = []
    for arg, arg_type in zip(call_args, arg_types):
        if "*" in arg_type:
            var_name = f"var_{next(self.arg_var_id)}"
            self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});")
            new_args.append(f"({arg_type})({var_name})")
        else:
            # Scalar argument: pass through verbatim.
            new_args.append(arg)
    # Debug-printer integration for cpp kernel calls.
    debug_printer_manager = V.graph.wrapper_code.debug_printer
    debug_printer_manager.set_printer_args(
        call_args,
        kernel_name,
        None,
        None,
        "cpp",
    )
    with debug_printer_manager:
        self.writeline(self.wrap_kernel_call(kernel_name, new_args))
|
| 153 |
+
|
| 154 |
+
def write_wrapper_decl(self):
|
| 155 |
+
inputs_len = len(V.graph.graph_inputs.keys())
|
| 156 |
+
if V.graph.aot_mode:
|
| 157 |
+
if (
|
| 158 |
+
config.aot_inductor.use_minimal_arrayref_interface
|
| 159 |
+
and not V.graph.is_const_graph
|
| 160 |
+
):
|
| 161 |
+
input_cpp_types = ", ".join(
|
| 162 |
+
f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}"
|
| 163 |
+
for x in V.graph.graph_inputs.values()
|
| 164 |
+
)
|
| 165 |
+
output_arrayref_types = ", ".join(
|
| 166 |
+
f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
|
| 167 |
+
for x in V.graph.graph_outputs
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self.prefix.splice(
|
| 171 |
+
f"""
|
| 172 |
+
using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
|
| 173 |
+
using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
|
| 174 |
+
"""
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
if V.graph.const_module:
|
| 178 |
+
self.header.splice(V.graph.const_module.wrapper_code.header)
|
| 179 |
+
|
| 180 |
+
assert V.graph.const_wrapper_code is not None
|
| 181 |
+
self.prefix.splice(V.graph.const_wrapper_code)
|
| 182 |
+
|
| 183 |
+
assert V.graph.const_kernel_code is not None
|
| 184 |
+
self.kernel_declarations.splice(V.graph.const_kernel_code)
|
| 185 |
+
|
| 186 |
+
if V.graph.is_const_graph:
|
| 187 |
+
self.prefix.splice(
|
| 188 |
+
"""
|
| 189 |
+
void AOTInductorModel::_const_run_impl(
|
| 190 |
+
std::vector<AtenTensorHandle>& output_handles,
|
| 191 |
+
DeviceStreamType stream,
|
| 192 |
+
AOTIProxyExecutorHandle proxy_executor
|
| 193 |
+
) {
|
| 194 |
+
"""
|
| 195 |
+
)
|
| 196 |
+
else:
|
| 197 |
+
if not config.aot_inductor.use_runtime_constant_folding:
|
| 198 |
+
# If we do not split the constant graph, we'll just create
|
| 199 |
+
# an empty implementation when wrapping the main module.
|
| 200 |
+
self.prefix.splice(
|
| 201 |
+
"""
|
| 202 |
+
void AOTInductorModel::_const_run_impl(
|
| 203 |
+
std::vector<AtenTensorHandle>& output_handles,
|
| 204 |
+
DeviceStreamType stream,
|
| 205 |
+
AOTIProxyExecutorHandle proxy_executor
|
| 206 |
+
) {}
|
| 207 |
+
|
| 208 |
+
"""
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
run_impl_proto = """
|
| 212 |
+
void AOTInductorModel::run_impl(
|
| 213 |
+
AtenTensorHandle*
|
| 214 |
+
input_handles, // array of input AtenTensorHandle; handles
|
| 215 |
+
// are stolen; the array itself is borrowed
|
| 216 |
+
AtenTensorHandle*
|
| 217 |
+
output_handles, // array for writing output AtenTensorHandle; handles
|
| 218 |
+
// will be stolen by the caller; the array itself is
|
| 219 |
+
// borrowed
|
| 220 |
+
DeviceStreamType stream,
|
| 221 |
+
AOTIProxyExecutorHandle proxy_executor
|
| 222 |
+
) {
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
self.generate_input_output_runtime_checks()
|
| 226 |
+
run_impl_proto += """
|
| 227 |
+
__check_inputs_outputs(input_handles, output_handles);
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
if config.aot_inductor.use_minimal_arrayref_interface:
|
| 231 |
+
self.prefix.splice(
|
| 232 |
+
"""
|
| 233 |
+
template <>
|
| 234 |
+
AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
|
| 235 |
+
AOTInductorModelInputs, AOTInductorModelOutputs>(
|
| 236 |
+
const AOTInductorModelInputs& inputs,
|
| 237 |
+
DeviceStreamType stream,
|
| 238 |
+
AOTIProxyExecutorHandle proxy_executor
|
| 239 |
+
) {
|
| 240 |
+
"""
|
| 241 |
+
)
|
| 242 |
+
self.suffix.splice(run_impl_proto)
|
| 243 |
+
self.suffix.splice(
|
| 244 |
+
"""
|
| 245 |
+
AOTInductorModelInputs inputs;
|
| 246 |
+
convert_handles_to_inputs(input_handles, inputs);
|
| 247 |
+
auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
|
| 248 |
+
inputs, stream, proxy_executor);
|
| 249 |
+
// NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
|
| 250 |
+
// interface to perform well for a DSO using the minimal arrayref interface, all we need
|
| 251 |
+
// to do is provide ThreadLocalCachedTensor for each one!
|
| 252 |
+
convert_outputs_to_handles(outputs, output_handles);
|
| 253 |
+
}
|
| 254 |
+
"""
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.suffix.splice(
|
| 258 |
+
"""
|
| 259 |
+
extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
|
| 260 |
+
AOTInductorModelHandle model_handle,
|
| 261 |
+
const AOTInductorModelInputs& inputs,
|
| 262 |
+
AOTInductorModelOutputs& outputs) {
|
| 263 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 264 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 265 |
+
outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
|
| 266 |
+
inputs,
|
| 267 |
+
(torch::aot_inductor::DeviceStreamType)nullptr,
|
| 268 |
+
nullptr);
|
| 269 |
+
})
|
| 270 |
+
}
|
| 271 |
+
"""
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
self.prefix.splice(run_impl_proto)
|
| 275 |
+
else:
|
| 276 |
+
# cpp entry function for JIT with cpp wrapper
|
| 277 |
+
self.prefix.splice(
|
| 278 |
+
"""
|
| 279 |
+
void inductor_entry_impl(
|
| 280 |
+
AtenTensorHandle*
|
| 281 |
+
input_handles, // array of input AtenTensorHandle; handles
|
| 282 |
+
// are stolen; the array itself is borrowed
|
| 283 |
+
AtenTensorHandle*
|
| 284 |
+
output_handles // array for writing output AtenTensorHandle; handles
|
| 285 |
+
// will be stolen by the caller; the array itself is
|
| 286 |
+
// borrowed)
|
| 287 |
+
) {
|
| 288 |
+
"""
|
| 289 |
+
)
|
| 290 |
+
with self.prefix.indent():
|
| 291 |
+
# assign inputs and outputs in both cases so the later codegen can be simplified
|
| 292 |
+
if not config.aot_inductor.use_minimal_arrayref_interface:
|
| 293 |
+
if not V.graph.is_const_graph:
|
| 294 |
+
if V.graph.aot_mode:
|
| 295 |
+
num_args = len(V.graph.graph_inputs)
|
| 296 |
+
else:
|
| 297 |
+
# Weights are promoted in the JIT mode
|
| 298 |
+
num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
|
| 299 |
+
# release GIL to support multiple instances inference (in different threads of the same process)
|
| 300 |
+
self.prefix.splice("py::gil_scoped_release release;")
|
| 301 |
+
|
| 302 |
+
self.prefix.splice(
|
| 303 |
+
f"""
|
| 304 |
+
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
|
| 305 |
+
"""
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if inputs_len != 0:
|
| 309 |
+
for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
|
| 310 |
+
if config.aot_inductor.use_minimal_arrayref_interface:
|
| 311 |
+
self.prefix.writeline(
|
| 312 |
+
f"auto {input_key} = std::get<{idx}>(inputs);"
|
| 313 |
+
)
|
| 314 |
+
continue
|
| 315 |
+
# unwrap input tensor back to scalar
|
| 316 |
+
if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
|
| 317 |
+
from ..graph import may_get_constant_buffer_dtype
|
| 318 |
+
|
| 319 |
+
dtype = may_get_constant_buffer_dtype(
|
| 320 |
+
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
|
| 321 |
+
)
|
| 322 |
+
assert dtype is not None, (
|
| 323 |
+
"Fails to get the dtype of the sympy.Expr"
|
| 324 |
+
)
|
| 325 |
+
self.codegen_tensor_item(
|
| 326 |
+
dtype, f"inputs[{idx}]", input_key, self.prefix
|
| 327 |
+
)
|
| 328 |
+
else:
|
| 329 |
+
self.prefix.writeline(
|
| 330 |
+
f"auto {input_key} = std::move(inputs[{idx}]);"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
assert all(
|
| 334 |
+
isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
|
| 335 |
+
), "Expect all constants to be Tensor"
|
| 336 |
+
for idx, constants_key in enumerate(V.graph.constants.keys()):
|
| 337 |
+
if V.graph.aot_mode:
|
| 338 |
+
# Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
|
| 339 |
+
# Don't call std::move here because it will cause constants_ to lose the ownership.
|
| 340 |
+
self.prefix.writeline(
|
| 341 |
+
f"""auto {constants_key} = constants_->at({idx});"""
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
# Append constants as inputs to the graph
|
| 345 |
+
constants_idx = inputs_len + idx
|
| 346 |
+
self.prefix.writeline(
|
| 347 |
+
f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
self.codegen_inputs()
|
| 351 |
+
|
| 352 |
+
if V.graph.aot_mode:
|
| 353 |
+
if not V.graph.is_const_graph:
|
| 354 |
+
if config.aot_inductor.use_minimal_arrayref_interface:
|
| 355 |
+
# TODO: input shape checking for regular tensor interface as well?
|
| 356 |
+
self.codegen_input_numel_asserts()
|
| 357 |
+
else:
|
| 358 |
+
self.prefix.writeline("inputs.clear();")
|
| 359 |
+
self.prefix.writeline(
|
| 360 |
+
"[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
def generate_return(self, output_refs: list[str]):
    """Emit the C++ epilogue that hands graph outputs back to the caller.

    For each entry of *output_refs* (a C++ expression naming the output
    value), writes code into ``self.wrapper_call`` that either fills
    ``output_handles[idx]`` or, under the minimal-arrayref interface,
    populates ``std::get<idx>(output_arrayref_tensors)`` and finally
    returns that tuple.

    Special cases handled below:
      * ``"nullptr"`` outputs are skipped entirely.
      * Constant buffers are cloned (see NOTE(return_constant)).
      * Scalar (ShapeAsConstantBuffer) outputs are wrapped into tensors.
      * Duplicate output expressions reuse the first emitted handle
        (tracked in ``output2idx``).
    """
    cst_names = V.graph.constants.keys()
    arr_iface = (
        not V.graph.is_const_graph
        and config.aot_inductor.use_minimal_arrayref_interface
    )  # For brevity.

    def use_thread_local_cached_output_tensor(idx, output):
        # Fallback path (taken at C++ compile time when the output is not a
        # tensor-handle type): stage the value through a thread_local cache.
        cached_output_name = f"cached_output_{next(self.cached_output_id)}"
        cache_type = "Array" if arr_iface else "Tensor"
        self.wrapper_call.writeline(
            f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
            f"{cached_output_name}({output});"
        )
        if arr_iface:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
            element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
            self.wrapper_call.writeline(
                f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
            )
        else:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                f"output_handles[{idx}]));"
            )

    if arr_iface:
        self.wrapper_call.writeline(
            "AOTInductorModelOutputs output_arrayref_tensors;"
        )

    # Maps an output expression to the first index it was emitted at, so
    # repeated outputs can alias the already-written handle.
    output2idx: dict[str, int] = {}
    for idx, output in enumerate(output_refs):
        if output == "nullptr":
            continue

        is_constant_buffer = output in cst_names
        output_buffer = V.graph.graph_outputs[idx]
        if isinstance(output_buffer, ir.BaseView):
            output_storage = output_buffer.unwrap_view()
            if isinstance(output_storage.data, ir.ConstantBuffer):
                is_constant_buffer = True

        if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
            # Need to wrap scalar into tensor as the main function returns a vector of tensors
            output_tensor = self.codegen_scalar_to_tensor(output)
            self.wrapper_call.writeline(
                f"output_handles[{idx}] = {output_tensor}.release();"
            )
            continue

        # C++ compile-time dispatch: handle-like outputs take the fast path,
        # everything else goes through the thread_local cache helper above.
        output_is_tensor_handle_expr = (
            f"std::is_same_v<std::decay_t<decltype({output})>,"
            "RAIIAtenTensorHandle> || "
            f"std::is_same_v<std::decay_t<decltype({output})>,"
            "AtenTensorHandle> || "
            f"std::is_same_v<std::decay_t<decltype({output})>,"
            "ConstantHandle>"
        )
        self.wrapper_call.writeline(
            f"if constexpr ({output_is_tensor_handle_expr}) {{"
        )
        with self.wrapper_call.indent():
            if arr_iface:
                cached_output_name = f"cached_output_{next(self.cached_output_id)}"
                self.wrapper_call.writeline(
                    f"thread_local RAIIAtenTensorHandle {cached_output_name};"
                )
                if is_constant_buffer:
                    # NOTE(return_constant): In some rare cases where we return
                    # a constant, we have to return a copy of this constant,
                    # because (1) constants are not owned by the Model instance
                    # (2) constants remain the same cross inference runs,
                    # assuming they are not updated at runtime Basically, we
                    # cannot release or transfer the ownership of any original
                    # constant to the user.
                    self.wrapper_call.writeline(
                        f"AtenTensorHandle {cached_output_name}_tmp;"
                    )
                    self.wrapper_call.writeline(
                        f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
                    )
                    self.wrapper_call.writeline(
                        f"{cached_output_name} = {cached_output_name}_tmp;"
                    )
                else:
                    self.wrapper_call.writeline(
                        f"{cached_output_name} = {output}.release();"
                    )
                self.wrapper_call.writeline(
                    f"convert_handle_to_arrayref_tensor({cached_output_name}, "
                    f"std::get<{idx}>(output_arrayref_tensors));"
                )
            else:
                if is_constant_buffer:
                    # See NOTE(return_constant) above.
                    self.wrapper_call.writeline(
                        f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                    )
                else:
                    if output in output2idx:
                        src_idx = output2idx[output]
                        self.wrapper_call.writeline(
                            f"output_handles[{idx}] = output_handles[{src_idx}];"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"output_handles[{idx}] = {output}.release();"
                        )
        self.wrapper_call.writeline("} else {")
        with self.wrapper_call.indent():
            use_thread_local_cached_output_tensor(idx, output)
        self.wrapper_call.writeline("}")

        if output not in output2idx:
            output2idx[output] = idx
    if arr_iface:
        self.wrapper_call.writeline("return output_arrayref_tensors;")
| 490 |
+
|
| 491 |
+
def memory_plan(self):
    """Run the memory planner over the recorded wrapper lines.

    Rewrites ``self.lines`` in place via ``MemoryPlanner.plan`` and then
    turns stack allocation off, since the planner and stack allocation
    are not integrated yet.
    """
    from .memory_planning import MemoryPlanner

    planner = MemoryPlanner(self)
    self.lines = planner.plan(self.lines)
    # TODO: integrate memory planning & stack allocation?
    self.allow_stack_allocation = False
| 497 |
+
|
| 498 |
+
def memory_plan_reuse(self):
    """Plan buffer reuse over ``self.lines`` and decide on stack allocation.

    First drops trailing MemoryPlanningLine entries that do not feed a graph
    output, then runs each remaining line's ``plan`` against a stack of
    per-subgraph MemoryPlanningState objects (pushed/popped on subgraph
    enter/exit). Finally disables stack allocation if the conservative total
    of allocated buffer bytes exceeds MAX_STACK_ALLOCATION_SIZE or it was
    already disabled.
    """
    out_names = V.graph.get_output_names()

    while (
        self.lines
        and isinstance(self.lines[-1], MemoryPlanningLine)
        # TODO: this seems legit, NullLine has no node
        and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
    ):
        # these lines will be pointless
        self.lines.pop()

    # codegen allocations in two passes
    planning_states = [MemoryPlanningState()]
    past_planning_states = []
    for i in range(len(self.lines)):
        line = self.lines[i]
        if isinstance(line, MemoryPlanningLine):
            self.lines[i] = line.plan(planning_states[-1])
        elif isinstance(line, EnterSubgraphLine):
            planning_states.append(MemoryPlanningState())
        elif isinstance(line, ExitSubgraphLine):
            past_planning_states.append(planning_states.pop())
    # Retire the top-level state as well; every pushed state must be popped.
    past_planning_states.append(planning_states.pop())
    assert len(planning_states) == 0

    # conservatively use the sum of all allocated buffer sizes
    # in potentially nested scopes as the total allocated size
    total_allocated_buffer_size = sum(
        s.total_allocated_buffer_size for s in past_planning_states
    )

    self.allow_stack_allocation = (
        self.allow_stack_allocation is not False
        and config.aot_inductor.allow_stack_allocation
        and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
    )
|
| 535 |
+
|
| 536 |
+
def can_stack_allocate_buffer(self, buffer):
    """Return whether *buffer* is eligible for C++ stack allocation.

    Requires stack allocation to be globally enabled, a CPU buffer, a
    provably static shape, and contiguous strides for that shape.
    """
    if not self.allow_stack_allocation:
        return False
    if buffer.get_device().type != "cpu":
        return False
    if not self.can_prove_buffer_has_static_shape(buffer):
        return False
    return ir.is_contiguous_strides_for_shape(
        buffer.get_stride(), buffer.get_size()
    )
|
| 545 |
+
|
| 546 |
+
def make_buffer_free(self, buffer):
    """Return the C++ statement that releases *buffer*, or "" when none is needed.

    No free is emitted for multi-output layouts, for stack-allocated buffers
    in AOT mode, or for graph inputs under the minimal arrayref interface in
    AOT mode.
    """
    name = buffer.get_name()
    no_free_needed = (
        isinstance(buffer.get_output_spec(), ir.MultiOutputLayout)
        or (V.graph.aot_mode and name in self.stack_allocated_buffers)
        or (
            config.aot_inductor.use_minimal_arrayref_interface
            and V.graph.aot_mode
            and name in V.graph.graph_inputs
        )
    )
    return "" if no_free_needed else f"{name}.reset();"
|
| 558 |
+
|
| 559 |
+
def make_buffer_allocation(self, buffer):
    """Emit the allocation statement for *buffer*.

    Thin adapter over :meth:`make_allocation` that forwards the buffer's
    metadata and passes the buffer itself when it qualifies for stack
    allocation (signalling make_allocation to emit a stack array).
    """
    stack_candidate = buffer if self.can_stack_allocate_buffer(buffer) else None
    return self.make_allocation(
        buffer.get_name(),
        buffer.get_device(),
        buffer.get_dtype(),
        buffer.get_size(),
        buffer.get_stride(),
        stack_candidate,
    )
|
| 568 |
+
|
| 569 |
+
def make_allocation(
    self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
):
    """Emit C++ that allocates a tensor named *name* and return its declaration.

    When *buffer_if_can_stack_allocate* is provided, the storage is declared
    as a plain C array on the stack and wrapped in an ``ArrayRefTensor``;
    otherwise an ``AtenTensorHandle`` is created via
    ``aoti_torch_empty_strided`` and wrapped in a ``RAIIAtenTensorHandle``.
    Supporting size/stride array declarations are written to
    ``self.wrapper_call`` as a side effect.
    """
    orig_stride = stride
    device_str = self.codegen_device(device)
    dtype_code = self.codegen_dtype(dtype)
    size = self.codegen_shape_tuple(shape)
    stride = self.codegen_shape_tuple(orig_stride)
    size_array_var = self.codegen_int_array_var(
        size,
        self.wrapper_call.writeline,
        known_statically=self.is_statically_known_list_of_ints(shape),
        graph=self.get_codegened_graph(),
    )
    stride_array_var = self.codegen_int_array_var(
        stride,
        self.wrapper_call.writeline,
        known_statically=self.is_statically_known_list_of_ints(orig_stride),
        graph=self.get_codegened_graph(),
    )
    # codegen_device returns "<type>,<index>"; in AOT mode the model's own
    # device index overrides the parsed index.
    device_type, device_id = device_str.split(",")
    device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
    if buffer_if_can_stack_allocate is not None:
        self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
        cpp_type = DTYPE_TO_CPP[dtype]
        numel = buffer_if_can_stack_allocate.get_numel()
        # Note: we don't zero storage because empty_strided doesn't zero either.
        self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
        args = [
            f"{name}_storage",
            size_array_var,
            stride_array_var,
            device_type,
            device_idx,
        ]
        return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"

    args = [
        str(len(shape)),
        size_array_var,
        stride_array_var,
        dtype_code,
        device_type,
        device_idx,
        f"&{name}_handle",
    ]

    self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
    self.wrapper_call.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
    )

    return f"RAIIAtenTensorHandle {name}({name}_handle);"
|
| 622 |
+
|
| 623 |
+
def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
    """Emit C++ that reuses *old*'s storage for *new* and return the statement.

    Exact size/stride matches reuse the handle directly; otherwise the old
    buffer is reinterpreted to the new layout. When *delete_old* is set and
    the old buffer is not a graph output, a free of the old buffer is
    appended to the emitted line.
    """
    assert old.get_dtype() == new.get_dtype()
    old_name = old.get_name()
    new_name = new.get_name()
    del_line = ";"
    if old_name not in V.graph.get_output_names() and delete_old:
        del_line = f"; {self.make_buffer_free(old)}"

    if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
        # Propagate stack-allocation status so later frees stay correct.
        if old_name in self.stack_allocated_buffers:
            self.stack_allocated_buffers[new_name] = new
        return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)

    reinterpret_view = self.codegen_reinterpret_view(
        old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline
    )
    if reinterpret_view in self.stack_allocated_buffers:
        self.stack_allocated_buffers[new_name] = new
        # The only way to get into this case is via an exact buffer reuse, since all
        # other options result in a new tensor handle.
        return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
    return f"{self.declare}{new_name} = {reinterpret_view}{del_line} // reuse"
|
| 645 |
+
|
| 646 |
+
def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
|
| 647 |
+
# Borrowing arguments to shim functions is only safe because we know
|
| 648 |
+
# that the arguments can't be stack-allocated. Otherwise, to be sure
|
| 649 |
+
# we can't return a dangling pointer, we need to either 1) be
|
| 650 |
+
# certain that the shim function cannot return an alias of a
|
| 651 |
+
# borrowed argument, or 2) be certain that the returned Tensor from
|
| 652 |
+
# the shim function cannot escape.
|
| 653 |
+
assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), (
|
| 654 |
+
"borrowing arguments to shim functions is unsafe with "
|
| 655 |
+
"stack allocation on! (see comment above this assertion)"
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
    """Borrowing is safe only when no buffer is (or can be) stack-allocated."""
    if self.allow_stack_allocation:
        return False
    return not self.stack_allocated_buffers
|
| 660 |
+
|
| 661 |
+
def generate_c_shim_extern_kernel_call(
    self, kernel: str, args: list[str], device: str, **_
) -> None:
    """Emit a fallback-aten call through the C shim layer.

    Tensor-like string arguments are wrapped in
    ``borrow_arrayref_tensor_as_tensor``; stack allocation is disabled
    because the ArrayRefTensor <-> at::Tensor exchange is still fragile.
    """
    self.allow_stack_allocation = False

    def _maybe_borrow(arg):
        # We only really *need* borrow_arrayref_tensor_as_tensor for
        # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
        # which borrow_arrayref_tensor_as_tensor would blindly coerce to int,
        # so just avoid wrapping integers. Name matching to find tensors is
        # hacky, but fixing all the ArrayRefTensor issues is not a priority
        # for now.
        if isinstance(arg, str) and arg.startswith(
            ("buf", "arg", "wrap_with_raii_handle_if_needed")
        ):
            self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
            return f"borrow_arrayref_tensor_as_tensor({arg})"
        return arg

    wrapped_args = [_maybe_borrow(arg) for arg in args]

    super().generate_c_shim_extern_kernel_call(
        kernel, wrapped_args, device, debug_args=args
    )
|
| 686 |
+
|
| 687 |
+
def generate_scatter_fallback(
    self,
    output,
    inputs,
    cpp_kernel_name,
    python_kernel_name,
    src_is_tensor,
    reduce,
    kwargs,
):
    """Emit a C-shim call for a scatter / scatter_reduce fallback.

    Builds the out-variant shim call with every tensor argument borrowed as
    an ATen tensor; *output* doubles as the out argument. ``reduce``/
    ``kwargs`` are appended per the kernel variant being generated.
    """
    # No stack allocation when there is a fallback op
    self.allow_stack_allocation = False

    # call the ABI shim function instead of the ATen one
    cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
    # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
    cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
    self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
    inputs_wrapped = [
        (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x))
        for x in inputs
    ]
    line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}"

    if python_kernel_name.startswith("aten.scatter_reduce"):
        line += f", {','.join(kwargs)}"
    else:
        if src_is_tensor:
            if reduce:
                line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
        else:
            assert reduce is None, (
                "Expect reduce to be None for aten.scatter_ with scalar src"
            )
    line += ");"
    self.writeline(line)
|
| 723 |
+
|
| 724 |
+
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
    """Emit a C-shim call for the index_put fallback; mutates *x* in place.

    All tensor arguments are borrowed as ATen tensors, and the indices are
    staged through a temporary array so their handles stay alive for the
    duration of the call.
    """
    # No stack allocation when there is a fallback op
    self.allow_stack_allocation = False

    self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
    # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
    # See the comment in codegen_reinterpret_view about why having something like
    # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
    # tensor prematurely deallocated, thus the temporary array trick here.
    indices_str = self._generate_temporary_array_pointer(
        "AtenTensorHandle",
        [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices],
    )
    args = [
        f"borrow_arrayref_tensor_as_tensor({x})",
        indices_str,
        str(len(indices)),
        f"borrow_arrayref_tensor_as_tensor({values})",
        accumulate,
    ]
    args.insert(
        0, f"borrow_arrayref_tensor_as_tensor({x})"
    )  # set x as the output tensor, this fallback mutates x.
    self.writeline(self.wrap_kernel_call(kernel, args))
|
| 748 |
+
|
| 749 |
+
def generate_fallback_kernel_with_runtime_lookup(
    self,
    buf_name: str,
    python_kernel_name: str,
    get_args: Callable[[], Sequence[str]],
    op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
    raw_args: Sequence[Any],
    outputs: Sequence[ir.Buffer],
) -> None:
    """Delegate fallback-kernel codegen to the parent class.

    Overridden only to disable stack allocation first: fallback ops exchange
    real ATen tensors, which is incompatible with stack-allocated
    ArrayRefTensor storage.
    """
    # No stack allocation when there is a fallback op
    self.allow_stack_allocation = False
    super().generate_fallback_kernel_with_runtime_lookup(
        buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs
    )
|
| 763 |
+
|
| 764 |
+
def codegen_device_copy(self, src, dst, non_blocking: bool):
    """Emit a tensor copy from *src* into *dst* through the C shim."""
    # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
    # while stack-allocation results in ArrayRefTensor
    # so disable stack allocation here
    self.allow_stack_allocation = False
    self.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
    )
|
| 772 |
+
|
| 773 |
+
def codegen_reinterpret_view(
    self,
    data,
    size,
    stride,
    offset,
    writeline: Callable[..., None],
    dtype=None,
) -> str:
    """Returns a newly-created, temporary RAII tensor handle containing the
    reinterpreted tensor data. Callers of this function are responsible for saving
    the handle if persistent access is needed."""
    dim = str(len(size))

    def create_reinterpret_call() -> str:
        # Builds the reinterpret_tensor_wrapper(...) expression for a true
        # layout change.
        # NOTE(review): this helper is not referenced anywhere in this method
        # body — confirm whether it is dead code or used via the superclass.
        args = [
            f"{data.get_name()}",
            dim,
            self.codegen_int_array_var(
                self.codegen_shape_tuple(size),
                writeline,
                known_statically=self.is_statically_known_list_of_ints(size),
                graph=self.get_codegened_graph(),
            ),
            self.codegen_int_array_var(
                self.codegen_shape_tuple(stride),
                writeline,
                known_statically=self.is_statically_known_list_of_ints(stride),
                graph=self.get_codegened_graph(),
            ),
            offset,
        ]
        return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))"

    def create_new_tensor_handle() -> tuple[str, list[str]]:
        # Calling reset() on ArrayRefTensor does nothing, since the array is
        # const-allocated on the stack. Thus, it's safe to return a reference to
        # the original array.
        if (name := data.get_name()) in self.stack_allocated_buffers:
            return name, []

        tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}"
        tmp_call_strs = [
            f"AtenTensorHandle {tmp_AtenTensorHandle};",
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));",
        ]
        return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs

    # Fast path: the requested view is layout-identical to the source, so a
    # fresh handle to the same storage suffices.
    if (
        size == data.layout.size
        and stride == data.layout.stride
        and offset == data.layout.offset
        and (dtype is None or dtype == data.dtype)
    ):
        final_tensor_str, call_strs = create_new_tensor_handle()
        for line in call_strs:
            writeline(line)
        return final_tensor_str

    return super().codegen_reinterpret_view(
        data, size, stride, offset, writeline, dtype
    )
|
| 835 |
+
|
| 836 |
+
def val_to_arg_str(self, val, type_=None) -> str:
    """Render *val* as a C++ argument string.

    Optional tensors get the same special-casing as in the parent class,
    except that ArrayRefTensor values must first be borrowed or copied into
    real ATen tensors (borrowing only when provably safe).
    """
    is_optional_tensor = (
        val is not None
        and isinstance(type_, torch.OptionalType)
        and isinstance(type_.getElementType(), torch.TensorType)
    )
    if not is_optional_tensor:
        return super().val_to_arg_str(val, type_)

    # Handle optional tensors as a special case, as in the parent class.
    base_handle = self.val_to_arg_str(val, torch.TensorType)
    if config.aot_inductor.use_minimal_arrayref_interface:
        if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor():
            base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})"
        else:
            base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})"
    return f"&temporary_reference({base_handle}.get())"
|
| 852 |
+
|
| 853 |
+
def codegen_tensor_item(
    self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
):
    """Emit C++ that extracts the scalar value of *tensor* into variable *scalar*.

    For float16/bfloat16 the raw value is read into a temporary of the
    native C++ type and then widened to float, since the emitted scalar is
    declared as ``float`` in that branch.
    """
    dtype_str = str(dtype).split(".")[-1]
    # Write to the caller-supplied buffer (e.g. self.prefix) when given,
    # otherwise to this wrapper directly.
    writer = indented_buffer or self

    if dtype == torch.float16 or dtype == torch.bfloat16:
        scalar_tmp = f"{scalar}_tmp"
        writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")

        # We know that item_ doesn't alias the input, so borrowing should be safe.
        tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"

        writer.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
        )
        writer.writeline(f"float {scalar} = float({scalar_tmp});")
    else:
        writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")

        # We know that item_ doesn't alias the input, so borrowing should be safe.
        tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"

        writer.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
        )
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py
ADDED
|
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import dataclasses
|
| 5 |
+
import re
|
| 6 |
+
from itertools import count, zip_longest
|
| 7 |
+
from typing import Any, Optional, Union
|
| 8 |
+
from typing_extensions import Self
|
| 9 |
+
|
| 10 |
+
import sympy
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import dtype as torch_dtype
|
| 14 |
+
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
|
| 15 |
+
from torch._inductor.runtime.runtime_utils import dynamo_timed
|
| 16 |
+
|
| 17 |
+
from .. import config
|
| 18 |
+
from ..codecache import CudaKernelParamCache
|
| 19 |
+
from ..ir import (
|
| 20 |
+
GraphPartitionSignature,
|
| 21 |
+
TensorBox,
|
| 22 |
+
TMADescriptorExperimental,
|
| 23 |
+
TMADescriptorStable,
|
| 24 |
+
)
|
| 25 |
+
from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer
|
| 26 |
+
from ..virtualized import V
|
| 27 |
+
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
| 28 |
+
from .common import get_device_op_overrides, TritonScratchWorkspace
|
| 29 |
+
from .cpp_utils import cexpr
|
| 30 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 31 |
+
from .multi_kernel import MultiKernelCall
|
| 32 |
+
from .triton_utils import should_unwrap_unspec_arg
|
| 33 |
+
from .wrapper import PythonWrapperCodegen, SymbolicCallArg
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_cpp_string_literal_escapes = {
|
| 37 |
+
"\\": "\\\\",
|
| 38 |
+
'"': '\\"',
|
| 39 |
+
"\n": "\\n",
|
| 40 |
+
"\t": "\\t",
|
| 41 |
+
"\r": "\\r",
|
| 42 |
+
}
|
| 43 |
+
_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def cpp_string_literal(s: str) -> str:
|
| 47 |
+
escaped = _cpp_string_literal_pattern.sub(
|
| 48 |
+
lambda match: _cpp_string_literal_escapes[match.group(0)], s
|
| 49 |
+
)
|
| 50 |
+
return f'"{escaped}"'
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclasses.dataclass
|
| 54 |
+
class DeferredTritonCallWrapper:
|
| 55 |
+
"""
|
| 56 |
+
When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
|
| 57 |
+
to be tuned and stored as cubin files, so use a deferred generating the final wrapper around
|
| 58 |
+
the triton kernel until right before the prefix is written.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
wrapper_name: str
|
| 62 |
+
kernel_name: str
|
| 63 |
+
kernel_name_to_body: dict[str, str]
|
| 64 |
+
arg_types: list[Any]
|
| 65 |
+
|
| 66 |
+
def generate(self, wrapper: CppWrapperGpu):
|
| 67 |
+
"""
|
| 68 |
+
Generate the GPU kernel definition, as well as load and launch code.
|
| 69 |
+
"""
|
| 70 |
+
prefix = wrapper.prefix
|
| 71 |
+
if self.kernel_name.startswith("multi_kernel_"):
|
| 72 |
+
# MultiKernel will select one kernel after running the autotune block
|
| 73 |
+
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
| 74 |
+
params = CudaKernelParamCache.get(self.kernel_name)
|
| 75 |
+
assert params, f"CudaKernelParamCache not populated for {self.kernel_name}"
|
| 76 |
+
def_args = params["def_args"]
|
| 77 |
+
arg_types = self.arg_types
|
| 78 |
+
inductor_meta = params["inductor_meta"]
|
| 79 |
+
|
| 80 |
+
if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types):
|
| 81 |
+
# extra_launcher_args should already be in def_args
|
| 82 |
+
assert len(def_args) == len(arg_types) - len(
|
| 83 |
+
inductor_meta["extra_launcher_args"]
|
| 84 |
+
)
|
| 85 |
+
arg_types = arg_types + [SymbolicCallArg] * len(
|
| 86 |
+
inductor_meta["extra_launcher_args"]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
if not V.graph.aot_mode:
|
| 90 |
+
prefix.writeline(
|
| 91 |
+
maybe_hipify_code_wrapper(
|
| 92 |
+
f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;"
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
kernel_var_name = self.kernel_name
|
| 96 |
+
else:
|
| 97 |
+
kernel_var_name = f"kernels_.{self.kernel_name}"
|
| 98 |
+
|
| 99 |
+
# tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types
|
| 100 |
+
template_types = [
|
| 101 |
+
f"typename {name}_type_"
|
| 102 |
+
for name, arg_type in zip(def_args, arg_types)
|
| 103 |
+
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg))
|
| 104 |
+
]
|
| 105 |
+
if V.graph.aot_mode:
|
| 106 |
+
template_types.append("typename kernels_type_")
|
| 107 |
+
if template_types:
|
| 108 |
+
prefix.writeline(f"template <{', '.join(template_types)}>")
|
| 109 |
+
prefix.writeline(f"static inline void {self.wrapper_name}(")
|
| 110 |
+
with prefix.indent():
|
| 111 |
+
assert len(def_args) == len(arg_types), (def_args, arg_types)
|
| 112 |
+
for name, arg_type in zip(def_args, arg_types):
|
| 113 |
+
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)):
|
| 114 |
+
prefix.writeline(f"const {name}_type_& {name},")
|
| 115 |
+
elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)):
|
| 116 |
+
prefix.writeline(f"int64_t {name},")
|
| 117 |
+
elif arg_type is float:
|
| 118 |
+
prefix.writeline(f"float {name},")
|
| 119 |
+
elif arg_type is bool:
|
| 120 |
+
prefix.writeline(f"bool {name},")
|
| 121 |
+
else:
|
| 122 |
+
raise ValueError(f"Unexpected arg type {arg_type}")
|
| 123 |
+
prefix.writeline("int32_t device_idx_,")
|
| 124 |
+
prefix.writeline(
|
| 125 |
+
maybe_hipify_code_wrapper(
|
| 126 |
+
f"{wrapper.device_codegen.cpp_stream_type()} stream_,"
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
if V.graph.aot_mode:
|
| 130 |
+
prefix.writeline("kernels_type_& kernels_,")
|
| 131 |
+
prefix.writeline(
|
| 132 |
+
"const std::optional<std::string>& cubin_dir_ = std::nullopt"
|
| 133 |
+
)
|
| 134 |
+
prefix.writeline("){")
|
| 135 |
+
with prefix.indent():
|
| 136 |
+
if V.graph.aot_mode:
|
| 137 |
+
# Emit the original Triton kernel for debugging purposes
|
| 138 |
+
prefix.writeline("/*")
|
| 139 |
+
prefix.splice(self.kernel_name_to_body[self.kernel_name])
|
| 140 |
+
prefix.writeline("*/")
|
| 141 |
+
self.generate_grid(prefix, inductor_meta, params)
|
| 142 |
+
self.generate_load_kernel(prefix, kernel_var_name, params)
|
| 143 |
+
self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
|
| 144 |
+
prefix.writeline("}")
|
| 145 |
+
|
| 146 |
+
if not config.aot_inductor.embed_kernel_binary:
|
| 147 |
+
# Ensure the cubin file is included in the package
|
| 148 |
+
V.graph.wrapper_code.additional_files.append(
|
| 149 |
+
params[get_cpp_wrapper_cubin_path_name()]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def generate_grid(
|
| 153 |
+
self,
|
| 154 |
+
prefix: IndentedBuffer,
|
| 155 |
+
inductor_meta: dict[str, Any],
|
| 156 |
+
params: dict[str, Any],
|
| 157 |
+
):
|
| 158 |
+
from ..runtime.triton_heuristics import GridExpr
|
| 159 |
+
|
| 160 |
+
grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp")
|
| 161 |
+
for line in grid.prefix:
|
| 162 |
+
prefix.writeline(line)
|
| 163 |
+
prefix.splice(
|
| 164 |
+
f"""\
|
| 165 |
+
uint32_t grid_0 = {grid.x_grid};
|
| 166 |
+
uint32_t grid_1 = {grid.y_grid};
|
| 167 |
+
uint32_t grid_2 = {grid.z_grid};
|
| 168 |
+
"""
|
| 169 |
+
)
|
| 170 |
+
prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;")
|
| 171 |
+
|
| 172 |
+
def generate_load_kernel(self, prefix, kernel_var_name, params):
|
| 173 |
+
prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
|
| 174 |
+
with prefix.indent():
|
| 175 |
+
embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"]
|
| 176 |
+
if torch.xpu.is_available():
|
| 177 |
+
# XPU needs the end address of the kernel to calculate the size of the kernel binary.
|
| 178 |
+
embed_kernel_args.append(
|
| 179 |
+
f"__{params['inductor_meta']['kernel_name']}_end"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
load_kernel_args = (
|
| 183 |
+
[
|
| 184 |
+
*embed_kernel_args,
|
| 185 |
+
cpp_string_literal(params["mangled_name"]),
|
| 186 |
+
str(params["shared_mem"]),
|
| 187 |
+
]
|
| 188 |
+
if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary
|
| 189 |
+
else [
|
| 190 |
+
cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
|
| 191 |
+
cpp_string_literal(params["mangled_name"]),
|
| 192 |
+
str(params["shared_mem"]),
|
| 193 |
+
"cubin_dir_",
|
| 194 |
+
]
|
| 195 |
+
)
|
| 196 |
+
prefix.writeline(
|
| 197 |
+
f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
|
| 198 |
+
)
|
| 199 |
+
prefix.writeline("}")
|
| 200 |
+
|
| 201 |
+
def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params):
|
| 202 |
+
triton_meta = params["triton_meta"]
|
| 203 |
+
assert len(self.arg_types) == len(params["def_args"]), (
|
| 204 |
+
self.arg_types,
|
| 205 |
+
params["def_args"],
|
| 206 |
+
)
|
| 207 |
+
arg_type_loookup = dict(zip(params["def_args"], self.arg_types))
|
| 208 |
+
# difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants
|
| 209 |
+
call_args = [
|
| 210 |
+
name for name in params["call_args"] if name not in triton_meta["constants"]
|
| 211 |
+
]
|
| 212 |
+
arg_types = [arg_type_loookup[name] for name in call_args]
|
| 213 |
+
arg_signatures = [triton_meta["signature"][name] for name in call_args]
|
| 214 |
+
call_args_str = wrapper.generate_args_decl(
|
| 215 |
+
prefix,
|
| 216 |
+
call_args,
|
| 217 |
+
arg_types,
|
| 218 |
+
arg_signatures,
|
| 219 |
+
workspace_size=params.get("global_scratch") or 0,
|
| 220 |
+
)
|
| 221 |
+
prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
|
| 222 |
+
launch_kernel_args = [
|
| 223 |
+
kernel_var_name,
|
| 224 |
+
"grid_0",
|
| 225 |
+
"grid_1",
|
| 226 |
+
"grid_2",
|
| 227 |
+
str(params["num_warps"]),
|
| 228 |
+
str(params["shared_mem"]),
|
| 229 |
+
"kernel_args_",
|
| 230 |
+
"stream_",
|
| 231 |
+
]
|
| 232 |
+
prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CppWrapperGpu(CppWrapperCpu):
|
| 236 |
+
"""
|
| 237 |
+
Generates cpp wrapper for running on GPU and calls CUDA kernels
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
def __init__(self) -> None:
|
| 241 |
+
self.device = get_gpu_type()
|
| 242 |
+
self.device_codegen = get_device_op_overrides(self.device)
|
| 243 |
+
super().__init__()
|
| 244 |
+
self.grid_id = count()
|
| 245 |
+
self._kernel_name_to_body: dict[str, str] = {}
|
| 246 |
+
self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {}
|
| 247 |
+
self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT"
|
| 248 |
+
|
| 249 |
+
@staticmethod
|
| 250 |
+
def create(
|
| 251 |
+
is_subgraph: bool,
|
| 252 |
+
subgraph_name: Optional[str],
|
| 253 |
+
parent_wrapper: Optional[PythonWrapperCodegen],
|
| 254 |
+
partition_signatures: Optional[GraphPartitionSignature] = None,
|
| 255 |
+
):
|
| 256 |
+
# TODO - support subgraph codegen by lifting functions. Check the
|
| 257 |
+
# comment at CppWrapperCpu `codegen_subgraph` function.
|
| 258 |
+
return CppWrapperGpu()
|
| 259 |
+
|
| 260 |
+
def write_header(self):
|
| 261 |
+
if V.graph.is_const_graph:
|
| 262 |
+
# We do not write header for constant graph, it will be written by main module.
|
| 263 |
+
return
|
| 264 |
+
|
| 265 |
+
super().write_header()
|
| 266 |
+
self.header.splice(
|
| 267 |
+
maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
@cache_on_self
|
| 271 |
+
def write_tma_descriptor_helpers_once(self):
|
| 272 |
+
self.header.splice(self.device_codegen.tma_descriptor_helpers())
|
| 273 |
+
|
| 274 |
+
def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
|
| 275 |
+
name = f"stream{device_idx}"
|
| 276 |
+
self.writeline(
|
| 277 |
+
maybe_hipify_code_wrapper(
|
| 278 |
+
f"{self.device_codegen.cpp_stream_type()} {name};"
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
self.writeline(
|
| 282 |
+
f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));"
|
| 283 |
+
)
|
| 284 |
+
return name
|
| 285 |
+
|
| 286 |
+
def get_autotuning_input_name(self, idx):
|
| 287 |
+
return f"{self.autotune_input_prefix}_{idx}"
|
| 288 |
+
|
| 289 |
+
def codegen_inputs(self):
|
| 290 |
+
# See Note: [Input Alignment handling in Inductor]
|
| 291 |
+
#
|
| 292 |
+
# JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to
|
| 293 |
+
# copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp.
|
| 294 |
+
|
| 295 |
+
if config.is_fbcode():
|
| 296 |
+
# TODO: This is added because FC. Remove this once the newly added shim symbols,
|
| 297 |
+
# e.g. aoti_torch_clone_preserve_strides, have landed
|
| 298 |
+
return super().codegen_inputs()
|
| 299 |
+
|
| 300 |
+
if V.graph.aot_mode and V.graph.inputs_to_check:
|
| 301 |
+
for idx in V.graph.inputs_to_check:
|
| 302 |
+
input_name = V.graph.graph_input_names[idx]
|
| 303 |
+
assert input_name in V.graph.graph_inputs, (
|
| 304 |
+
f"{input_name} not found in graph inputs"
|
| 305 |
+
)
|
| 306 |
+
value = V.graph.graph_inputs[input_name]
|
| 307 |
+
assert isinstance(value, TensorBox), (
|
| 308 |
+
f"{input_name} is expected to be tensor but found as {type(value)}"
|
| 309 |
+
)
|
| 310 |
+
warn_msg = (
|
| 311 |
+
f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
|
| 312 |
+
"but it is not aligned at run time. Copying to an aligned tensor "
|
| 313 |
+
"to guarantee correctness, but expect a performance hit."
|
| 314 |
+
)
|
| 315 |
+
self.prefix.splice(
|
| 316 |
+
f"""
|
| 317 |
+
if ((long({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{
|
| 318 |
+
AOTI_TORCH_WARN("{warn_msg}");
|
| 319 |
+
AtenTensorHandle {input_name}_aligned;
|
| 320 |
+
aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned);
|
| 321 |
+
{input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned));
|
| 322 |
+
}}
|
| 323 |
+
"""
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
super().codegen_inputs()
|
| 327 |
+
|
| 328 |
+
def _define_kernel_helper(
|
| 329 |
+
self,
|
| 330 |
+
kernel_name: str,
|
| 331 |
+
kernel_body: str,
|
| 332 |
+
metadata: Optional[str] = None,
|
| 333 |
+
gpu: bool = True,
|
| 334 |
+
cpp_definition: Optional[str] = None,
|
| 335 |
+
):
|
| 336 |
+
if gpu:
|
| 337 |
+
self._kernel_name_to_body[kernel_name] = kernel_body
|
| 338 |
+
if config.triton.autotune_at_compile_time:
|
| 339 |
+
# Call PythonWrapperCodegen to create the autotune code block
|
| 340 |
+
PythonWrapperCodegen._define_kernel_helper(
|
| 341 |
+
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
return CppWrapperCpu._define_kernel_helper(
|
| 345 |
+
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
def generate(self, is_inference):
|
| 349 |
+
with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
|
| 350 |
+
return super().generate(is_inference)
|
| 351 |
+
|
| 352 |
+
def finalize_prefix(self):
|
| 353 |
+
"""Define the triton kernels now that autotuning is finished"""
|
| 354 |
+
old_prefix = self.prefix # new content should go at start of prefix
|
| 355 |
+
|
| 356 |
+
# Generating triton kernel callers can modify the prefix (cached dtypes),
|
| 357 |
+
# so do this before running finalize_prefix(), but put the generated code
|
| 358 |
+
# after the finalize_prefix() code.
|
| 359 |
+
self.prefix = IndentedBuffer()
|
| 360 |
+
for kernel in self._triton_call_wrappers.values():
|
| 361 |
+
self.prefix.writeline("\n")
|
| 362 |
+
kernel.generate(self)
|
| 363 |
+
triton_prefix = self.prefix
|
| 364 |
+
|
| 365 |
+
self.prefix = IndentedBuffer()
|
| 366 |
+
super().finalize_prefix()
|
| 367 |
+
|
| 368 |
+
self.prefix.splice(triton_prefix)
|
| 369 |
+
|
| 370 |
+
self.prefix.writeline("\n")
|
| 371 |
+
self.prefix.splice(old_prefix)
|
| 372 |
+
|
| 373 |
+
def generate_tma_descriptor(self, desc):
|
| 374 |
+
self.write_tma_descriptor_helpers_once()
|
| 375 |
+
|
| 376 |
+
if isinstance(desc, TMADescriptorExperimental):
|
| 377 |
+
self._generate_experimental_tma_descriptor(desc)
|
| 378 |
+
else:
|
| 379 |
+
assert isinstance(desc, TMADescriptorStable)
|
| 380 |
+
self._generate_stable_tma_descriptor(desc)
|
| 381 |
+
|
| 382 |
+
def _generate_experimental_tma_descriptor(self, desc):
|
| 383 |
+
# generate data pointer for the source tensor
|
| 384 |
+
source = self.generate_args_decl(
|
| 385 |
+
code=self,
|
| 386 |
+
call_args=[self.val_to_arg_str(desc.tensor)],
|
| 387 |
+
arg_types=[desc.tensor.get_dtype()],
|
| 388 |
+
arg_signatures=[None],
|
| 389 |
+
# these args are passed to initNDTMADescriptor, which is NOT a triton kernel
|
| 390 |
+
is_triton_kernel=False,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
desc_name = desc.name
|
| 394 |
+
self.writeline(f"alignas(64) CUtensorMap {desc_name};")
|
| 395 |
+
|
| 396 |
+
# `source` is in the form of `&var_x`, where `var_x` is the data pointer
|
| 397 |
+
# (CUdeviceptr); we dereference `source` and cast to `void*` to pass to
|
| 398 |
+
# the data pointer of the source tensor to the helper function
|
| 399 |
+
# `init{1,2}DTMADescriptor`
|
| 400 |
+
ptr = f"reinterpret_cast<void*>(*({source}))"
|
| 401 |
+
dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims)
|
| 402 |
+
block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims)
|
| 403 |
+
element_size = self.val_to_arg_str(desc.element_size)
|
| 404 |
+
fn = f"init{desc.rank}DTMADescriptor"
|
| 405 |
+
args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}"
|
| 406 |
+
self.writeline(f"{fn}({args});")
|
| 407 |
+
|
| 408 |
+
def _generate_stable_tma_descriptor(self, desc):
|
| 409 |
+
source = self.generate_args_decl(
|
| 410 |
+
code=self,
|
| 411 |
+
call_args=[self.val_to_arg_str(desc.tensor)],
|
| 412 |
+
arg_types=[desc.tensor.get_dtype()],
|
| 413 |
+
arg_signatures=[None],
|
| 414 |
+
# these args are passed to initNDTMADescriptor, which is NOT a triton kernel
|
| 415 |
+
is_triton_kernel=False,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
desc_name = desc.name
|
| 419 |
+
# Pack the relevant information into a StableTMADescriptor struct.
|
| 420 |
+
# See [Note: AOTI TMA Stable handling] for more details.
|
| 421 |
+
self.writeline(f"alignas(64) StableTMADescriptor {desc_name};")
|
| 422 |
+
|
| 423 |
+
def fill_array(name, values):
|
| 424 |
+
for i, val in enumerate(values):
|
| 425 |
+
self.writeline(f"{name}[{i}] = {val};")
|
| 426 |
+
|
| 427 |
+
ptr = f"reinterpret_cast<void*>(*({source}))"
|
| 428 |
+
rank = len(desc.tensor.get_size())
|
| 429 |
+
|
| 430 |
+
fill_array(f"{desc_name}.block_shape", desc.block_shape)
|
| 431 |
+
fill_array(f"{desc_name}.global_shape", desc.tensor.get_size())
|
| 432 |
+
fill_array(f"{desc_name}.strides", desc.tensor.get_stride())
|
| 433 |
+
|
| 434 |
+
element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize)
|
| 435 |
+
fn = "initTMADescriptor"
|
| 436 |
+
args = ", ".join(
|
| 437 |
+
str(x)
|
| 438 |
+
for x in [
|
| 439 |
+
f"&{desc_name}.m",
|
| 440 |
+
ptr,
|
| 441 |
+
element_size,
|
| 442 |
+
rank,
|
| 443 |
+
f"{desc_name}.block_shape",
|
| 444 |
+
f"{desc_name}.global_shape",
|
| 445 |
+
f"{desc_name}.strides",
|
| 446 |
+
]
|
| 447 |
+
)
|
| 448 |
+
self.writeline(f"{fn}({args});")
|
| 449 |
+
|
| 450 |
+
def generate_args_decl(
|
| 451 |
+
self,
|
| 452 |
+
code: Union[IndentedBuffer, Self],
|
| 453 |
+
call_args,
|
| 454 |
+
arg_types,
|
| 455 |
+
arg_signatures,
|
| 456 |
+
is_triton_kernel=True,
|
| 457 |
+
workspace_size=0,
|
| 458 |
+
):
|
| 459 |
+
"""
|
| 460 |
+
Generates any declarations of args to pass into a kernel call, and then returns the arg names.
|
| 461 |
+
|
| 462 |
+
In more detail:
|
| 463 |
+
* declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;`
|
| 464 |
+
* returns: a string with the list of args, e.g. "var_0, var_1"
|
| 465 |
+
|
| 466 |
+
call_args: list of call arguments
|
| 467 |
+
arg_types: list of argument types
|
| 468 |
+
arg_signatures: list with signatures of all the args
|
| 469 |
+
is_triton_kernel: whether these are passed into a triton kernel or not. In particular,
|
| 470 |
+
calls to triton kernels will have an additional global scratch space
|
| 471 |
+
arg injected at the front of the arg list.
|
| 472 |
+
"""
|
| 473 |
+
new_args: list[str] = []
|
| 474 |
+
|
| 475 |
+
# Add more cases for other types as needed
|
| 476 |
+
signature2dtype = {
|
| 477 |
+
"i32": "int32_t",
|
| 478 |
+
"i64": "int64_t",
|
| 479 |
+
"fp32": "float",
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
def signature_is_tma_desc(sig):
|
| 483 |
+
if not sig:
|
| 484 |
+
return False
|
| 485 |
+
if sig == "nvTmaDesc":
|
| 486 |
+
return True
|
| 487 |
+
if sig.startswith("tensordesc<"):
|
| 488 |
+
return True
|
| 489 |
+
return False
|
| 490 |
+
|
| 491 |
+
def process_tma_stable_arg(arg, arg_type, arg_signature, var_name):
|
| 492 |
+
# [Note: AOTI TMA Stable handling]
|
| 493 |
+
# For most args, a single arg passed to the python triton interface
|
| 494 |
+
# maps to a single arg in the cubin interface. However, for host-side
|
| 495 |
+
# TMA descriptors, a single python arg turns into 1 + 2 * N args in the
|
| 496 |
+
# cubin interface (where N is the rank).
|
| 497 |
+
#
|
| 498 |
+
# To do this: at TMA codegen time (for aoti), we generate a struct
|
| 499 |
+
# (StableTMADescriptor) containing the necessary information; and then
|
| 500 |
+
# when we call the function (i.e. here), we unpack the struct members.
|
| 501 |
+
code.writeline(f"auto {var_name} = {cexpr(arg)};")
|
| 502 |
+
|
| 503 |
+
result = []
|
| 504 |
+
result.append(f"&{var_name}.m")
|
| 505 |
+
|
| 506 |
+
# from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111 # noqa: B950
|
| 507 |
+
match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature)
|
| 508 |
+
assert match is not None
|
| 509 |
+
shape = match.group(2)
|
| 510 |
+
ndim = shape.count(",") + 1
|
| 511 |
+
|
| 512 |
+
for i in range(ndim):
|
| 513 |
+
result.append(f"&{var_name}.block_shape[{i}]")
|
| 514 |
+
|
| 515 |
+
for i in range(ndim):
|
| 516 |
+
result.append(f"&{var_name}.strides[{i}]")
|
| 517 |
+
|
| 518 |
+
return result
|
| 519 |
+
|
| 520 |
+
def process_args(arg, arg_type, arg_signature=None):
|
| 521 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 522 |
+
# ignore tma descriptors, as host-side TMA descriptors need
|
| 523 |
+
# to be passed to the compiled Triton kernel by value
|
| 524 |
+
if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc(
|
| 525 |
+
arg_signature
|
| 526 |
+
):
|
| 527 |
+
self.codegen_tensor_item(
|
| 528 |
+
arg_type.dtype,
|
| 529 |
+
arg,
|
| 530 |
+
var_name,
|
| 531 |
+
indented_buffer=code,
|
| 532 |
+
)
|
| 533 |
+
new_args.append(f"&{var_name}")
|
| 534 |
+
elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc(
|
| 535 |
+
arg_signature
|
| 536 |
+
):
|
| 537 |
+
device_ptr_type = self.device_codegen.cpp_device_ptr()
|
| 538 |
+
code.writeline(
|
| 539 |
+
maybe_hipify_code_wrapper(
|
| 540 |
+
f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
|
| 541 |
+
)
|
| 542 |
+
)
|
| 543 |
+
new_args.append(f"&{var_name}")
|
| 544 |
+
elif arg_type in (sympy.Integer, int):
|
| 545 |
+
code.writeline(f"int {var_name} = {cexpr(arg)};")
|
| 546 |
+
new_args.append(f"&{var_name}")
|
| 547 |
+
elif arg_type in (sympy.Float, float):
|
| 548 |
+
code.writeline(f"float {var_name} = {cexpr(arg)};")
|
| 549 |
+
new_args.append(f"&{var_name}")
|
| 550 |
+
# For symbolic call arguments, examine the arg signatures from triton meta
|
| 551 |
+
# to explicitly cast to the right type
|
| 552 |
+
# Reason: `auto` can infer unexpected type against kernel input signature.
|
| 553 |
+
elif (
|
| 554 |
+
isinstance(arg_type, type(SymbolicCallArg))
|
| 555 |
+
and arg_signature is not None
|
| 556 |
+
and arg_signature in signature2dtype.keys()
|
| 557 |
+
):
|
| 558 |
+
code.writeline(
|
| 559 |
+
f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
|
| 560 |
+
)
|
| 561 |
+
new_args.append(f"&{var_name}")
|
| 562 |
+
elif arg_signature and arg_signature.startswith("tensordesc<"):
|
| 563 |
+
new_args.extend(
|
| 564 |
+
process_tma_stable_arg(arg, arg_type, arg_signature, var_name)
|
| 565 |
+
)
|
| 566 |
+
else:
|
| 567 |
+
code.writeline(f"auto {var_name} = {cexpr(arg)};")
|
| 568 |
+
new_args.append(f"&{var_name}")
|
| 569 |
+
|
| 570 |
+
for arg, arg_type, arg_signature in zip_longest(
|
| 571 |
+
call_args, arg_types, arg_signatures
|
| 572 |
+
):
|
| 573 |
+
process_args(arg, arg_type, arg_signature)
|
| 574 |
+
|
| 575 |
+
if (
|
| 576 |
+
is_triton_kernel
|
| 577 |
+
and (
|
| 578 |
+
global_scratch := self.device_codegen.cpp_global_scratch(
|
| 579 |
+
next(self.arg_var_id),
|
| 580 |
+
workspace=TritonScratchWorkspace(
|
| 581 |
+
size=workspace_size,
|
| 582 |
+
generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
|
| 583 |
+
),
|
| 584 |
+
)
|
| 585 |
+
)
|
| 586 |
+
is not None
|
| 587 |
+
):
|
| 588 |
+
global_scratch_def, global_scratch_var = global_scratch
|
| 589 |
+
code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
|
| 590 |
+
new_args.append(f"&{global_scratch_var}")
|
| 591 |
+
|
| 592 |
+
return ", ".join(new_args)
|
| 593 |
+
|
| 594 |
+
def _generate_kernel_call_helper(
|
| 595 |
+
self,
|
| 596 |
+
kernel_name: str,
|
| 597 |
+
call_args,
|
| 598 |
+
*,
|
| 599 |
+
device=None,
|
| 600 |
+
triton=True,
|
| 601 |
+
arg_types=None,
|
| 602 |
+
raw_keys=None,
|
| 603 |
+
raw_args=None,
|
| 604 |
+
triton_meta=None,
|
| 605 |
+
graph_name="",
|
| 606 |
+
original_fxnode_name=None,
|
| 607 |
+
):
|
| 608 |
+
"""
|
| 609 |
+
Override the default value of argument 'gpu' to True here.
|
| 610 |
+
generate_kernel_call can still be called with gpu=False because of
|
| 611 |
+
a mix of cpu kernels and gpu kernels.
|
| 612 |
+
"""
|
| 613 |
+
device = device or V.graph.get_current_device_or_throw()
|
| 614 |
+
if device.type == "cpu":
|
| 615 |
+
# Even in CppWrapperGpu, we may see cpp kernels
|
| 616 |
+
return CppWrapperCpu._generate_kernel_call_helper(
|
| 617 |
+
self,
|
| 618 |
+
kernel_name,
|
| 619 |
+
call_args,
|
| 620 |
+
device=device,
|
| 621 |
+
triton=triton,
|
| 622 |
+
arg_types=arg_types,
|
| 623 |
+
raw_keys=raw_keys,
|
| 624 |
+
raw_args=raw_args,
|
| 625 |
+
triton_meta=triton_meta,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
if (
|
| 629 |
+
triton
|
| 630 |
+
and config.triton.autotune_at_compile_time
|
| 631 |
+
and kernel_name not in self.kernel_autotune_names
|
| 632 |
+
):
|
| 633 |
+
# Call PythonWrapperCodegen to create the autotune code block
|
| 634 |
+
PythonWrapperCodegen._generate_kernel_call_helper(
|
| 635 |
+
self,
|
| 636 |
+
kernel_name,
|
| 637 |
+
call_args,
|
| 638 |
+
device=device,
|
| 639 |
+
triton=triton,
|
| 640 |
+
arg_types=arg_types,
|
| 641 |
+
raw_keys=raw_keys,
|
| 642 |
+
raw_args=raw_args,
|
| 643 |
+
triton_meta=triton_meta,
|
| 644 |
+
original_fxnode_name=original_fxnode_name,
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
stream = (
|
| 648 |
+
"stream"
|
| 649 |
+
if V.graph.aot_mode
|
| 650 |
+
else self.write_get_raw_stream(device.index, graph_name)
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
if triton:
|
| 654 |
+
call_args, arg_types = self.prepare_triton_wrapper_args(
|
| 655 |
+
call_args, arg_types
|
| 656 |
+
)
|
| 657 |
+
wrapper_name = f"call_{kernel_name}"
|
| 658 |
+
if wrapper_name not in self._triton_call_wrappers:
|
| 659 |
+
self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper(
|
| 660 |
+
wrapper_name,
|
| 661 |
+
kernel_name,
|
| 662 |
+
self._kernel_name_to_body,
|
| 663 |
+
arg_types,
|
| 664 |
+
)
|
| 665 |
+
device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index)
|
| 666 |
+
call_args.append(device_idx)
|
| 667 |
+
call_args.append(stream)
|
| 668 |
+
if V.graph.aot_mode:
|
| 669 |
+
call_args.append("kernels")
|
| 670 |
+
call_args.append("this->cubin_dir_")
|
| 671 |
+
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
| 672 |
+
debug_printer_manager.set_printer_args(
|
| 673 |
+
call_args[: len(arg_types)], kernel_name, arg_types, None
|
| 674 |
+
)
|
| 675 |
+
with debug_printer_manager:
|
| 676 |
+
self.writeline(f"{wrapper_name}({', '.join(call_args)});")
|
| 677 |
+
else:
|
| 678 |
+
casted = []
|
| 679 |
+
for arg_type, arg in zip(arg_types, call_args):
|
| 680 |
+
new_arg = arg
|
| 681 |
+
if arg_type.endswith("*") and arg != "nullptr":
|
| 682 |
+
new_arg = f"{arg}.data_ptr()"
|
| 683 |
+
casted.append(f"({arg_type}){cexpr(new_arg)}")
|
| 684 |
+
call_args_str = ", ".join(casted)
|
| 685 |
+
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
|
| 686 |
+
|
| 687 |
+
@staticmethod
|
| 688 |
+
def prepare_triton_wrapper_args(
|
| 689 |
+
call_args: list[Any], arg_types: list[Any]
|
| 690 |
+
) -> tuple[list[Any], list[Any]]:
|
| 691 |
+
assert len(call_args) == len(arg_types), (call_args, arg_types)
|
| 692 |
+
new_args = []
|
| 693 |
+
new_args_types = []
|
| 694 |
+
for arg, arg_type in zip(call_args, arg_types):
|
| 695 |
+
if isinstance(arg, str):
|
| 696 |
+
if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg):
|
| 697 |
+
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
| 698 |
+
arg_type = UnwrapUnspecArg(dtype=arg_type)
|
| 699 |
+
new_args.append(arg)
|
| 700 |
+
elif isinstance(arg, bool):
|
| 701 |
+
new_args.append(str(arg).lower())
|
| 702 |
+
elif isinstance(arg, (int, float, SymbolicCallArg)):
|
| 703 |
+
new_args.append(str(arg))
|
| 704 |
+
else:
|
| 705 |
+
new_args.append(cexpr(V.graph.sizevars.simplify(arg)))
|
| 706 |
+
new_args_types.append(arg_type)
|
| 707 |
+
return new_args, new_args_types
|
| 708 |
+
|
| 709 |
+
def make_zero_buffer(self, name):
|
| 710 |
+
return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));"
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@dataclasses.dataclass
|
| 714 |
+
class UnwrapUnspecArg:
|
| 715 |
+
"""Marker that we need to call .item() on the tensor"""
|
| 716 |
+
|
| 717 |
+
dtype: torch_dtype
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional
|
| 2 |
+
|
| 3 |
+
import sympy
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ..ir import GraphPartitionSignature
|
| 8 |
+
from ..virtualized import V
|
| 9 |
+
from .cpp_wrapper_gpu import CppWrapperGpu
|
| 10 |
+
from .wrapper import PythonWrapperCodegen
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CppWrapperMps(CppWrapperGpu):
    """C++ wrapper codegen for the MPS (Apple Metal) backend.

    Emits AOTI-style C++ that fetches a kernel function from a compiled Metal
    library and binds/dispatches its arguments via the ``aoti_torch_mps_*``
    shim API.
    """

    @staticmethod
    def create(
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[GraphPartitionSignature] = None,
    ) -> "CppWrapperMps":
        # All arguments are ignored here; a fresh wrapper is returned
        # regardless of subgraph context.
        return CppWrapperMps()

    def _generate_kernel_call_helper(
        self,
        kernel_name: str,
        call_args: list[str],
        arg_types: Optional[list[type]] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """
        Generates MPS kernel call code. It should look something like:
        ```
        auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
        auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
        mps_lib_0_func->runCommandBlock([&] {
            mps_lib_0_func->startEncoding();
            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0);
            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1);
            ...
            mps_lib_0_func->dispatch(9);
        });
        ```
        """
        assert arg_types is not None

        # The last two entries of call_args/arg_types are the launch
        # configuration (threads, group_size); everything before them is a
        # kernel argument bound by positional index.
        new_args = []
        for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
            if isinstance(arg_type, torch.dtype):
                # Tensor argument.
                new_args.append(
                    f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n"
                )
            elif arg_type in (int, sympy.core.symbol.Symbol):
                # Scalar integer (possibly symbolic) argument.
                new_args.append(
                    f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n"
                )
            else:
                raise NotImplementedError(
                    f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
                )

        threads, group_size = call_args[-2], call_args[-1]
        if threads is None:
            raise NotImplementedError("No threads or group_size provided")
        elif group_size is None:
            # One-argument dispatch: only the total thread count.
            new_args.append(f"{kernel_name}->dispatch({threads});\n")
        else:
            new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n")

        # debug printer related logic for cpp kernel type.
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args[:-2],
            kernel_name,
            None,
            None,
            "cpp",
        )
        with debug_printer_manager:
            self.writeline(self.wrap_kernel_call(kernel_name, new_args))

    def wrap_kernel_call(self, name: str, call_args: list[str]) -> str:
        """Wrap pre-rendered set-arg/dispatch lines in a Metal command block.

        ``name`` is expected to end with "_func"; stripping that suffix yields
        the Metal library variable name the function is fetched from.
        """
        lib_name = name[: -len("_func")]
        calling_args = " ".join(call_args)
        return f"""
    auto {name} = {lib_name}.getKernelFunction("generated_kernel");
    auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
    {name}->runCommandBlock([&] {{
        {name}->startEncoding();
        {calling_args}
    }});
"""

    @staticmethod
    def get_device_include_path(device: str) -> str:
        """Return the #include lines for MPS AOTI codegen (AOT mode only)."""
        assert V.graph.aot_mode
        return (
            "#include <torch/csrc/inductor/aoti_include/mps.h>\n"
            "#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>"
        )
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from textwrap import dedent
|
| 4 |
+
|
| 5 |
+
from .common import DeviceOpOverrides, register_device_op_overrides
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CpuDeviceOpOverrides(DeviceOpOverrides):
    """Device-op code snippets for the CPU backend.

    CPU has no streams, device indices, or device guards, so most hooks
    return Python "pass" statements for the generated wrapper code.
    """

    def import_get_raw_stream_as(self, name: str) -> str:
        # NOTE(review): the emitted helper is always named `get_raw_stream`
        # and always returns stream 0; the `name` parameter is not used here.
        return dedent(
            """
            def get_raw_stream(_):
                return 0
            """
        )

    def set_device(self, device_idx: int) -> str:
        # No-op: CPU has no device selection.
        return "pass"

    def synchronize(self) -> str:
        # No-op: CPU execution is synchronous.
        return "pass"

    def device_guard(self, device_idx: int) -> str:
        # No-op: nothing to guard on CPU.
        return "pass"


register_device_op_overrides("cpu", CpuDeviceOpOverrides())
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import hashlib
|
| 3 |
+
import logging
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
from typing import cast
|
| 6 |
+
|
| 7 |
+
from torch._inductor.codegen.cuda.cutlass_python_evt import (
|
| 8 |
+
CutlassEVTCodegen,
|
| 9 |
+
MockCutlassHandler,
|
| 10 |
+
)
|
| 11 |
+
from torch._inductor.utils import Placeholder
|
| 12 |
+
from torch.utils._ordered_set import OrderedSet
|
| 13 |
+
|
| 14 |
+
from ...._dynamo.utils import counters
|
| 15 |
+
from ... import config
|
| 16 |
+
from ...codecache import code_hash, get_path
|
| 17 |
+
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
|
| 18 |
+
from ...scheduler import (
|
| 19 |
+
BaseSchedulerNode,
|
| 20 |
+
BaseScheduling,
|
| 21 |
+
FusedSchedulerNode,
|
| 22 |
+
SchedulerNode,
|
| 23 |
+
WhyNoFuse,
|
| 24 |
+
)
|
| 25 |
+
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
|
| 26 |
+
from ...virtualized import V
|
| 27 |
+
from ..common import BackendFeature, IndentedBuffer
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
log = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WhyNoFuseNames(WhyNoFuse):
    """WhyNoFuse variant constructed directly from two node *names*.

    NOTE(review): deliberately does not call super().__init__() — it only
    stores the two names, presumably consumed by the base class when the
    fusion-failure reason is reported; confirm against WhyNoFuse.
    """

    def __init__(self, name1: str, name2: str) -> None:
        self.name1 = name1
        self.name2 = name2
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CUDACPPScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for CUDA C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.

    It handles fusion decisions and CUDA C++ specific template code generation.
    """

    @classmethod
    def get_backend_features(cls, device) -> OrderedSet[BackendFeature]:
        # No optional backend features are advertised for CUDA C++ kernels.
        return OrderedSet()

    def group_fn(self, sizes):
        # Collapse each size group into a single simplified product expression.
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    @staticmethod
    def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
        """True iff `node` is a SchedulerNode wrapping a CUDATemplateBuffer."""
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CUDATemplateBuffer
        )

    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
        """True iff `node` is a FusedSchedulerNode built around a CUDA C++ template."""
        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """Decide whether `node2` may be fused as an epilogue onto `node1`.

        Two cases are handled: `node1` is a bare CUDA template (no epilogues
        yet), or `node1` is an already-fused template whose existing epilogues
        must be unwrapped and re-checked together with `node2`.
        """
        if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
            assert node1.node, "node1.node should not be None"
            return self._can_fuse_epilogue_impl(
                cast(CUDATemplateBuffer, node1.node),
                [],
                node2,  # type: ignore[arg-type]
            )
        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
            node2, BaseSchedulerNode
        ):
            assert node1.node, "node1.node should not be None"
            assert node2.node, "node2.node should not be None"
            fnode1 = cast(FusedSchedulerNode, node1)
            return self._can_fuse_epilogue_impl(
                fnode1.get_template_node(),  # type: ignore[arg-type]
                self._unwrap_epilogue_nodes(fnode1),
                node2,  # type: ignore[arg-type]
            )

        return False

    def define_kernel(self, src_code: str, node_schedule) -> str:
        """Register `src_code` with the wrapper and return its kernel name.

        The original (pre-name-substitution) source is used as the dedup key,
        so identical kernels are defined only once per graph.
        """
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )

            # use the original src_code as the key
            kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8]
            if fused_name == "fused":
                # no EVT kernel, use the original kernel name
                kernel_name = f"cutlass_{kernel_hash}"
            else:
                kernel_name = f"cutlass_{fused_name}_{kernel_hash}"
            wrapper.src_to_kernel[src_code] = kernel_name
            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name)

            _, _, kernel_path = get_path(code_hash(src_code), "py")

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline("async_compile.cuda(r'''")
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline(
                f"''', 'so', aot_compile={str(V.graph.aot_mode)})"
            )

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
        prologue_nodes: Sequence[BaseSchedulerNode],
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cuda_cpp_template(template_node), (
            "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
        )
        template_node = cast(SchedulerNode, template_node)
        _, (_numel, rnumel) = template_node.group
        assert rnumel == 1
        ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
            "Epilogue nodes must all be instances of ir.ComputedBuffer"
        )
        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)

        with kernel:
            for node in [template_node, *epilogue_nodes]:
                node.mark_run()

            # typically there is a codegen pass which runs after mark_run
            # for this kernel we've already generated the C++ code, but we still
            # need to let the kernel know about loads/stores that occur in the fused
            # kernel for memory planning to properly optimize allocations
            ctb.emulate_store_fn()
            for node in epilogue_ir_nodes:
                with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())):
                    assert isinstance(
                        node, ComputedBuffer
                    )  # Not sure why we need to do this again
                    node.get_store_function()(CutlassEVTCodegen.get_index_vars(node))

            with V.set_kernel_handler(kernel):
                src_code = render()
                node_schedule = [template_node, *epilogue_nodes]
                kernel_name = self.define_kernel(src_code, node_schedule)

        # debug printing values of intermediate tensors
        _, call_args, arg_signatures, _ = kernel.args.python_argdefs()
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args, kernel_name, arg_signatures, kernel
        )
        with debug_printer_manager:
            kernel.call_kernel(kernel_name, ctb)

        V.graph.removed_buffers |= kernel.removed_buffers
        self.free_buffers_in_scheduler()

    @staticmethod
    def _unwrap_epilogue_nodes(
        fused_node: FusedSchedulerNode,
    ) -> list[BaseSchedulerNode]:
        """Return the constituent nodes of `fused_node` minus its template node."""
        nodes = fused_node.get_nodes()
        template_node = fused_node.get_template_node()
        assert all(n.node is not None for n in nodes), (
            "All epilogue nodes should have an IRNode"
        )
        return cast(
            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
        )

    def _can_fuse_epilogue_impl(
        self,
        cuda_template_buffer: CUDATemplateBuffer,
        existing_epilogue_nodes: list[BaseSchedulerNode],
        node_to_fuse: BaseSchedulerNode,
    ) -> bool:
        """
        Check if the given node can be fused with the epilogue. At the moment, Kernels
        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

        Args:
            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
            existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
            node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
        Returns:
        - bool: True if the given node can be fused with the epilogue, False otherwise.

        """
        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())

        scheduler_nodes_to_fuse = node_to_fuse.get_nodes()

        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)

        # Checks on constituent nodes
        for s_node in scheduler_nodes_to_fuse:
            node = s_node.node

            if not isinstance(node, ComputedBuffer):
                why(f"{node} is not a ComputedBuffer")
                return False
            elif not isinstance(node.data, Pointwise):
                why(f"{node} is not a Pointwise op")
                return False
            elif not node.get_computed_buffer_name():  # type: ignore[attr-defined]
                why(f"{node} does not have a computed buffer name")
                return False

            name = node.get_computed_buffer_name()  # type: ignore[attr-defined]
            # dtype can differ, and strides can differ as long as they are broadcastable
            if node.get_size() != cuda_template_buffer.get_size():
                why(
                    f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
size: {cuda_template_buffer.get_size()}"
                )
                return False

        assert len(
            existing_epilogue_nodes
        ) or cuda_template_buffer.get_name() in OrderedSet(
            [rd.name for rd in node_to_fuse.read_writes.reads]
        ), "First epilogue node must read from cuda template buffer"

        if node_to_fuse.has_aliasing_or_mutation():
            why(f"{node_to_fuse.get_name()} has aliasing or mutation")
            return False
        elif node_to_fuse.is_reduction():
            why(
                f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT"
            )
            return False
        elif (
            not config.cuda.cutlass_epilogue_fusion_enabled
            or not config.epilogue_fusion
        ):
            why("cutlass epilogue fusion is not enabled")
            return False
        elif not cuda_template_buffer.supports_epilogue_fusion:
            why("epilogue fusion is only supported for TMA-enabled gemm ops")
            return False

        # Final check: attempt the actual EVT codegen; any NotImplementedError
        # means the epilogue contains an op/dtype EVT cannot express.
        try:
            from torch._inductor.codegen.cuda.cutlass_python_evt import (
                CutlassEVTCodegen,
            )

            CutlassEVTCodegen.ir_to_evt_python_code(
                cuda_template_buffer.get_name(),
                existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
                OrderedSet(),
            )

        except NotImplementedError as e:
            not_implemented_op = str(e)
            if not_implemented_op.startswith("_op_"):
                not_implemented_op = not_implemented_op[4:]
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
            else:  # Likely due to unsupported dtype.
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
Reason: {not_implemented_op}"  # noqa: G004, B950
                )
                return False

        return True
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
import shutil
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._inductor.utils import clear_on_fresh_cache
|
| 8 |
+
|
| 9 |
+
from ... import config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
log = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_cuda_arch() -> Optional[str]:
    """Return the target CUDA arch as a string (e.g. "90"), or None on error.

    Uses `config.cuda.arch` when set; otherwise derives it from the compute
    capability of device 0. Result is cached (and cleared on fresh cache).
    """
    try:
        cuda_arch = config.cuda.arch
        if cuda_arch is None:
            # Get Compute Capability of the first Visible device
            major, minor = torch.cuda.get_device_capability(0)
            return str(major * 10 + minor)
        return str(cuda_arch)
    except Exception as e:
        # Best effort: callers must tolerate a None result.
        log.error("Error getting cuda arch: %s", e)
        return None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_cuda_version() -> Optional[str]:
    """Return the CUDA toolkit version string, or None if it cannot be found.

    Prefers the explicit `config.cuda.version` override, falling back to the
    CUDA version torch was built against. Result is cached (and cleared on
    fresh cache).
    """
    try:
        configured = config.cuda.version
        return torch.version.cuda if configured is None else configured
    except Exception as exc:
        log.error("Error getting cuda version: %s", exc)
        return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@functools.cache
def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
    """Return True iff `nvcc_path` names an executable resolvable on PATH."""
    if nvcc_path is None:
        return False
    return shutil.which(nvcc_path) is not None
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py
ADDED
|
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
|
| 8 |
+
|
| 9 |
+
from sympy import Expr, symbols
|
| 10 |
+
|
| 11 |
+
import torch._inductor.config as config
|
| 12 |
+
from torch import dtype as torch_dtype
|
| 13 |
+
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
|
| 14 |
+
from torch._inductor.scheduler import BaseSchedulerNode
|
| 15 |
+
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
|
| 16 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 17 |
+
|
| 18 |
+
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .cuda_template import ArgInfo
|
| 23 |
+
|
| 24 |
+
from ...autotune_process import CUDABenchmarkRequest
|
| 25 |
+
from ...ir import (
|
| 26 |
+
Buffer,
|
| 27 |
+
ChoiceCaller,
|
| 28 |
+
CUDATemplateBuffer,
|
| 29 |
+
IRNode,
|
| 30 |
+
Layout,
|
| 31 |
+
PrimitiveInfoType,
|
| 32 |
+
TensorBox,
|
| 33 |
+
)
|
| 34 |
+
from ...utils import sympy_product
|
| 35 |
+
from ...virtualized import V
|
| 36 |
+
from ..common import (
|
| 37 |
+
CSEVariable,
|
| 38 |
+
IndentedBuffer,
|
| 39 |
+
Kernel,
|
| 40 |
+
OpOverrides,
|
| 41 |
+
WorkspaceArg,
|
| 42 |
+
WorkspaceZeroMode,
|
| 43 |
+
)
|
| 44 |
+
from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING:
|
| 48 |
+
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
|
| 49 |
+
|
| 50 |
+
log = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
cexpr = CppPrinter().doprint
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _normalize_idx(index: int, total_length: int) -> int:
|
| 56 |
+
return index if index >= 0 else index + total_length
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
|
| 60 |
+
ValidLayoutAttrs = Literal["size", "stride"]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass(frozen=True)
class LayoutArg:
    """Binds a layout symbol (M/N/K/B/lda/ldb/ldc/ldd) to one dim of one node.

    `attr` records whether the symbol refers to the node's size or stride,
    and `dim` which dimension it reads.
    """

    node: IRNode
    symbol: ValidLayoutSymbols
    attr: ValidLayoutAttrs
    dim: int

    def matches(self, node, attr, dim) -> bool:
        # Match on everything except the symbol name itself.
        return self.node == node and self.attr == attr and self.dim == dim
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CUDAKernel(Kernel):
    """
    Baseclass for CUDA / Cutlass based Kernels
    """

    overrides = OpOverrides  # type: ignore[assignment]

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Symbol name ("M", "lda", ...) -> LayoutArgs registered under it.
        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
        # Extra size expressions appended after the layout args.
        self.size_args: list[Union[Expr, int]] = []
        # Mapping from arg name to IRNode.
        self.named_nodes: dict[str, IRNode] = {}

    def find_symbol(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[str]:
        """Return the layout symbol registered for (node, attr, dim), if any."""
        arg = self.find_layout_arg(node, attr, dim)
        return arg.symbol if arg else None

    def find_layout_arg(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[LayoutArg]:
        """Return the LayoutArg matching (node, attr, dim), or None."""
        matches = [
            arg
            for arg in itertools.chain.from_iterable(self.layout_args.values())
            if arg.matches(node, attr, dim)
        ]
        if len(matches) >= 1:
            # Verify all matches have the same node, attribute, and dimension
            # And if they come from the same node, whichever symbol we use is fine.
            # if in runtime the logic changes, this would trigger guard
            first_match = matches[0]
            if not all(
                match.node == first_match.node
                and match.attr == first_match.attr
                and match.dim == first_match.dim
                for match in matches
            ):
                raise AssertionError("All matching layout args should be identical")
            return first_match
        return None

    def add_layout_arg(
        self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ):
        """Register that `symbol` reads `attr`[`dim`] of `node`."""
        arg = LayoutArg(node, symbol, attr, dim)
        self.layout_args[symbol].append(arg)

    def init_layout_args(self) -> None:
        """Register the GEMM layout symbols from the named X/W/Y(/Bias) nodes.

        Assumes a matmul-style layout where the last two dims of X are (M, K),
        of W are (K, N) and of Y are (M, N); a leading dim of X (if rank > 2)
        is the batch size B.
        """
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        x_mdim = _normalize_idx(-2, len(X.get_size()))
        x_kdim = _normalize_idx(-1, len(X.get_size()))
        w_kdim = _normalize_idx(-2, len(W.get_size()))
        w_ndim = _normalize_idx(-1, len(W.get_size()))
        y_mdim = _normalize_idx(-2, len(Y.get_size()))
        y_ndim = _normalize_idx(-1, len(Y.get_size()))
        self.add_layout_arg("M", X, "size", x_mdim)
        self.add_layout_arg("K", X, "size", x_kdim)
        self.add_layout_arg("K", W, "size", w_kdim)
        self.add_layout_arg("N", W, "size", w_ndim)
        self.add_layout_arg("M", Y, "size", y_mdim)
        self.add_layout_arg("N", Y, "size", y_ndim)
        if len(X.get_size()) > 2:
            self.add_layout_arg("B", X, "size", 0)

        lda_dim = self.find_ld_idx(X)
        ldb_dim = self.find_ld_idx(W)
        ldc_dim = self.find_ld_idx(Bias) if Bias else None
        ldd_dim = self.find_ld_idx(Y)
        self.add_layout_arg("lda", X, "stride", lda_dim)
        self.add_layout_arg("ldb", W, "stride", ldb_dim)
        if Bias is not None and ldc_dim is not None:
            self.add_layout_arg("ldc", Bias, "stride", ldc_dim)
        self.add_layout_arg("ldd", Y, "stride", ldd_dim)

    def get_layout_args(self) -> tuple[Union[Expr, int], ...]:
        """Return the concrete (M, N, K, B, LDA, LDB, LDC, LDD) values.

        Order must stay in sync with the symbols set up by init_layout_args.
        """
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        mdim = _normalize_idx(-2, len(X.get_size()))
        ndim = _normalize_idx(-1, len(W.get_size()))
        kdim = _normalize_idx(-1, len(X.get_size()))

        def get_ld(node) -> Union[Expr, int]:
            # Leading dimension = stride of the non-contiguous dim.
            dim = self.find_ld_idx(node)
            return node.get_stride()[dim]

        M = X.get_size()[mdim]
        N = W.get_size()[ndim]
        K = X.get_size()[kdim]
        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
        LDA = get_ld(X)
        LDB = get_ld(W)
        LDC = get_ld(Bias) if Bias else 0
        LDD = get_ld(Y)
        return (M, N, K, B, LDA, LDB, LDC, LDD)

    def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
        """Layout args followed by any extra recorded size args."""
        return [*self.get_layout_args(), *self.size_args]

    @staticmethod
    def find_ld_idx(node: IRNode) -> int:
        """Return the dim whose stride is the leading dimension of `node`.

        Requires one of the last two strides to be statically 1 (row- or
        column-major innermost dim); asserts otherwise.
        """
        strides = node.get_stride()
        # Handle 1D tensor case
        if V.graph.sizevars.statically_known_equals(strides[-1], 1):
            return _normalize_idx(-2, len(strides))

        assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2]
        return _normalize_idx(-1, len(strides))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class CUDATemplateKernel(CUDAKernel):
|
| 191 |
+
"""
|
| 192 |
+
Template kernels defined by CUDA / Cutlass in C++.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
kernel_name: str,
|
| 200 |
+
runtime_arg_info: list["ArgInfo"],
|
| 201 |
+
runtime_arg_values: list[Any],
|
| 202 |
+
) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Initializes a new instance of the CUDATemplateKernel class.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
kernel_name (str): The name of the kernel.
|
| 208 |
+
"""
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.kernel_name = kernel_name
|
| 211 |
+
self.runtime_arg_info = runtime_arg_info
|
| 212 |
+
self.runtime_arg_values = runtime_arg_values
|
| 213 |
+
|
| 214 |
+
def check_not_null(self, node: IRNode) -> str:
|
| 215 |
+
"""
|
| 216 |
+
Generates code to check that a node is not null.
|
| 217 |
+
"""
|
| 218 |
+
if node is None:
|
| 219 |
+
return ""
|
| 220 |
+
|
| 221 |
+
size_str = self.size(node, 0, -1)
|
| 222 |
+
name_str = self.arg_name(node)
|
| 223 |
+
if name_str is None:
|
| 224 |
+
return ""
|
| 225 |
+
|
| 226 |
+
res = IndentedBuffer(initial_indent=2)
|
| 227 |
+
res.tabwidth = 1
|
| 228 |
+
res.splice(
|
| 229 |
+
f"""
|
| 230 |
+
{{
|
| 231 |
+
if (!{name_str}) {{
|
| 232 |
+
int64_t {name_str}_size = {size_str};
|
| 233 |
+
if ({name_str}_size > 0) {{
|
| 234 |
+
throw std::runtime_error("input {name_str} is null but size is not 0!");
|
| 235 |
+
}}
|
| 236 |
+
}}
|
| 237 |
+
}}
|
| 238 |
+
"""
|
| 239 |
+
)
|
| 240 |
+
return res.getvalue()
|
| 241 |
+
|
| 242 |
+
def get_signature(self) -> str:
|
| 243 |
+
return self.signature
|
| 244 |
+
|
| 245 |
+
def def_kernel(
|
| 246 |
+
self,
|
| 247 |
+
inputs: list[IRNode],
|
| 248 |
+
outputs: list[IRNode],
|
| 249 |
+
names_str: str = "",
|
| 250 |
+
input_reorder: Optional[list[int]] = None,
|
| 251 |
+
) -> str:
|
| 252 |
+
"""
|
| 253 |
+
Hook called from template code to generate function definition and
|
| 254 |
+
needed args.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
inputs: List of input IRNodes
|
| 258 |
+
outputs: List of output IRNodes
|
| 259 |
+
names_str: Comma separated list of input + output argument names.
|
| 260 |
+
input_reorder: The actual order of input nodes.
|
| 261 |
+
e.g. The template might have input argument defined as [X, W, Bias],
|
| 262 |
+
and the actual input passed into this template could be [Bias, X, W].
|
| 263 |
+
In this case, the `input_reorder` would be [2, 0, 1].
|
| 264 |
+
additional_size_args: Additional size arguments for epilogue inputs
|
| 265 |
+
"""
|
| 266 |
+
names = [x.strip() for x in names_str.strip().split(",")]
|
| 267 |
+
if len(inputs) + len(outputs) != len(names):
|
| 268 |
+
raise RuntimeError(
|
| 269 |
+
f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
if input_reorder is not None:
|
| 273 |
+
assert len(inputs) == len(input_reorder)
|
| 274 |
+
else:
|
| 275 |
+
input_reorder = list(range(len(inputs)))
|
| 276 |
+
|
| 277 |
+
for idx in input_reorder:
|
| 278 |
+
name = names[idx]
|
| 279 |
+
node = inputs[idx]
|
| 280 |
+
if node is not None:
|
| 281 |
+
self.named_nodes[name] = node
|
| 282 |
+
self.args.input_buffers[node.get_name()] = name
|
| 283 |
+
|
| 284 |
+
free_symbols: OrderedSet[Expr] = OrderedSet()
|
| 285 |
+
for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
|
| 286 |
+
if node is not None:
|
| 287 |
+
self.named_nodes[name] = node
|
| 288 |
+
self.args.output_buffers[node.get_name()] = name
|
| 289 |
+
|
| 290 |
+
if name not in (
|
| 291 |
+
"X",
|
| 292 |
+
"W",
|
| 293 |
+
"Bias",
|
| 294 |
+
"Y",
|
| 295 |
+
): # we handle these symbolic shapes explicitly
|
| 296 |
+
for expr in itertools.chain(node.get_size(), node.get_stride()):
|
| 297 |
+
if isinstance(expr, Expr):
|
| 298 |
+
for s in expr.free_symbols:
|
| 299 |
+
free_symbols.add(s) # type: ignore[arg-type]
|
| 300 |
+
|
| 301 |
+
arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
|
| 302 |
+
|
| 303 |
+
self.init_layout_args()
|
| 304 |
+
size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
|
| 305 |
+
size_vars.extend(str(s) for s in free_symbols)
|
| 306 |
+
self.size_args.extend(free_symbols)
|
| 307 |
+
size_args = [f"const int {s}" for s in size_vars]
|
| 308 |
+
|
| 309 |
+
runtime_arg_decls = ",".join(
|
| 310 |
+
[f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
|
| 311 |
+
)
|
| 312 |
+
if runtime_arg_decls:
|
| 313 |
+
runtime_arg_decls += ", "
|
| 314 |
+
|
| 315 |
+
signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
|
| 316 |
+
self.signature = signature
|
| 317 |
+
return signature
|
| 318 |
+
|
| 319 |
+
def call_kernel(
|
| 320 |
+
self,
|
| 321 |
+
name: str,
|
| 322 |
+
node: "CUDATemplateBuffer", # type: ignore[name-defined]
|
| 323 |
+
) -> None:
|
| 324 |
+
"""
|
| 325 |
+
Generates code to call the kernel through V.graph.wrapper_code.
|
| 326 |
+
used from within torch._inductor.wrapper.PythonWrapperCodegen
|
| 327 |
+
|
| 328 |
+
name: Name of kernel function.
|
| 329 |
+
node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
|
| 330 |
+
as well as all required inputs and outputs.
|
| 331 |
+
"""
|
| 332 |
+
wrapper = V.graph.wrapper_code
|
| 333 |
+
|
| 334 |
+
arg_types: list[Any]
|
| 335 |
+
if V.graph.cpp_wrapper:
|
| 336 |
+
# Make sure we initialize these kernels since they're exported as
|
| 337 |
+
# C-style symbol names.
|
| 338 |
+
assert isinstance(wrapper, CppWrapperCpu)
|
| 339 |
+
wrapper.initialized_kernels[name] = self
|
| 340 |
+
# We always originally initialize name with "KERNEL_NAME". So, we
|
| 341 |
+
# we replace with the real kernel name passed as an arg to this function.
|
| 342 |
+
self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name)
|
| 343 |
+
_, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
|
| 344 |
+
else:
|
| 345 |
+
_, call_args, _, arg_types = self.args.python_argdefs()
|
| 346 |
+
|
| 347 |
+
dynamic_shape_args = self.get_dynamic_shape_args()
|
| 348 |
+
call_args.extend(dynamic_shape_args) # type: ignore[arg-type]
|
| 349 |
+
for arg in self.runtime_arg_values:
|
| 350 |
+
call_args.append(arg)
|
| 351 |
+
arg_types.extend("int" for _ in dynamic_shape_args)
|
| 352 |
+
for arg in self.runtime_arg_info:
|
| 353 |
+
arg_types.append(arg.ty)
|
| 354 |
+
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
| 355 |
+
for i in range(len(call_args)):
|
| 356 |
+
if V.graph.is_unspec_arg(call_args[i]):
|
| 357 |
+
call_args[i] = call_args[i] + ".item()"
|
| 358 |
+
elif isinstance(arg_types[i], torch_dtype):
|
| 359 |
+
call_args[i] = (
|
| 360 |
+
call_args[i]
|
| 361 |
+
if V.graph.cpp_wrapper
|
| 362 |
+
else f"c_void_p({call_args[i]}.data_ptr())"
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
|
| 366 |
+
# workspace_size should have already been retrieved prior to this call.
|
| 367 |
+
# workspace_size is here.
|
| 368 |
+
call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
|
| 369 |
+
if V.graph.cpp_wrapper:
|
| 370 |
+
arg_types.append("size_t*")
|
| 371 |
+
|
| 372 |
+
if node.get_workspace_size() > 0:
|
| 373 |
+
ws = WorkspaceArg(
|
| 374 |
+
count=node.get_workspace_size(),
|
| 375 |
+
device=V.graph.get_current_device_or_throw(),
|
| 376 |
+
zero_mode=WorkspaceZeroMode.UNINITIALIZED,
|
| 377 |
+
outer_name=WorkspaceArg.unique_name(),
|
| 378 |
+
)
|
| 379 |
+
wrapper.generate_workspace_allocation(ws)
|
| 380 |
+
workspace = str(ws.outer_name)
|
| 381 |
+
call_args.append(
|
| 382 |
+
workspace
|
| 383 |
+
if V.graph.cpp_wrapper
|
| 384 |
+
else f"c_void_p({workspace}.data_ptr())"
|
| 385 |
+
)
|
| 386 |
+
else:
|
| 387 |
+
ws = None
|
| 388 |
+
call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
|
| 389 |
+
if V.graph.cpp_wrapper:
|
| 390 |
+
arg_types.append("uint8_t*")
|
| 391 |
+
|
| 392 |
+
wrapper.generate_kernel_call(
|
| 393 |
+
name,
|
| 394 |
+
call_args,
|
| 395 |
+
triton=False,
|
| 396 |
+
arg_types=arg_types,
|
| 397 |
+
)
|
| 398 |
+
if ws:
|
| 399 |
+
wrapper.generate_workspace_deallocation(ws)
|
| 400 |
+
|
| 401 |
+
def dtype(self, node: IRNode) -> Optional[str]:
|
| 402 |
+
"""
|
| 403 |
+
Generates code which represents dtype of a given node.
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
if node is None:
|
| 407 |
+
return "void"
|
| 408 |
+
return DTYPE_TO_CPP.get(node.get_layout().dtype)
|
| 409 |
+
|
| 410 |
+
def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
|
| 411 |
+
# Helper method, called into from CUTLASSGemmTemplate
|
| 412 |
+
if node is None:
|
| 413 |
+
return default_dtype
|
| 414 |
+
from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate
|
| 415 |
+
|
| 416 |
+
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
|
| 417 |
+
|
| 418 |
+
def max_valid_index(self, node: IRNode, default=-1):
|
| 419 |
+
# Helper method, called into from CUTLASSGemmTemplate
|
| 420 |
+
if node is None:
|
| 421 |
+
return default
|
| 422 |
+
max_valid_offset = 0
|
| 423 |
+
for i in range(len(node.get_size())):
|
| 424 |
+
max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
|
| 425 |
+
return max_valid_offset
|
| 426 |
+
|
| 427 |
+
def offset(self, node: IRNode) -> str:
|
| 428 |
+
"""
|
| 429 |
+
Generates code which represents offset of a given node.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
if node is None:
|
| 433 |
+
return "0"
|
| 434 |
+
return str(node.get_layout().offset) # type: ignore[union-attr]
|
| 435 |
+
|
| 436 |
+
def ptr(self, node: IRNode) -> str:
|
| 437 |
+
"""
|
| 438 |
+
Generates code which represents pointer of a given node.
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
if node is None:
|
| 442 |
+
return "nullptr"
|
| 443 |
+
arg_name = self.arg_name(node)
|
| 444 |
+
if arg_name is None:
|
| 445 |
+
return "nullptr"
|
| 446 |
+
offset = self.offset(node)
|
| 447 |
+
return arg_name if offset == "0" else f"{arg_name} + {offset}"
|
| 448 |
+
|
| 449 |
+
def size(
|
| 450 |
+
self,
|
| 451 |
+
node: IRNode,
|
| 452 |
+
start_index: int,
|
| 453 |
+
end_index: Optional[int] = None,
|
| 454 |
+
default_value: int = 0,
|
| 455 |
+
) -> str:
|
| 456 |
+
"""
|
| 457 |
+
Hook called from template code to get the size of an arg.
|
| 458 |
+
Generates code which represents size of a given node in [start_index, end_index).
|
| 459 |
+
If node is None, returns default_value.
|
| 460 |
+
|
| 461 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 462 |
+
"""
|
| 463 |
+
|
| 464 |
+
if node is None:
|
| 465 |
+
return str(default_value)
|
| 466 |
+
|
| 467 |
+
start_index = _normalize_idx(start_index, len(node.get_size()))
|
| 468 |
+
if end_index is None:
|
| 469 |
+
end_index = start_index
|
| 470 |
+
end_index = _normalize_idx(end_index, len(node.get_size()))
|
| 471 |
+
sizes = [
|
| 472 |
+
self.find_symbol(node, "size", dim=i) or node.get_size()[i]
|
| 473 |
+
for i in range(start_index, end_index + 1)
|
| 474 |
+
]
|
| 475 |
+
if len(sizes) == 0:
|
| 476 |
+
return str(default_value)
|
| 477 |
+
|
| 478 |
+
sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
|
| 479 |
+
val = sympy_product(sizes)
|
| 480 |
+
return val
|
| 481 |
+
|
| 482 |
+
def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
|
| 483 |
+
"""
|
| 484 |
+
Hook called from template code to get the stride of an arg.
|
| 485 |
+
Generates code which represents stride of a given node at index.
|
| 486 |
+
If node is None, returns default_value.
|
| 487 |
+
|
| 488 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 489 |
+
"""
|
| 490 |
+
|
| 491 |
+
if node is None:
|
| 492 |
+
return str(default_value)
|
| 493 |
+
|
| 494 |
+
index = _normalize_idx(index, len(node.get_size()))
|
| 495 |
+
if index < 0:
|
| 496 |
+
return str(default_value)
|
| 497 |
+
|
| 498 |
+
stride = node.get_stride()[index]
|
| 499 |
+
if V.graph.sizevars.statically_known_leq(stride, 1):
|
| 500 |
+
return str(stride)
|
| 501 |
+
return self.find_symbol(node, "stride", dim=index) or str(stride)
|
| 502 |
+
|
| 503 |
+
def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
|
| 504 |
+
"""
|
| 505 |
+
Hook called from template code to get the batch stride of an arg.
|
| 506 |
+
Returns 0 if batch dim is not present.
|
| 507 |
+
|
| 508 |
+
This method assumes that batch stride is the largest stride.
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
if node is None:
|
| 512 |
+
return str(default_value)
|
| 513 |
+
|
| 514 |
+
if len(node.get_size()) < 3:
|
| 515 |
+
return str(default_value)
|
| 516 |
+
|
| 517 |
+
batch_stride = node.get_stride()[0]
|
| 518 |
+
if V.graph.sizevars.statically_known_leq(batch_stride, 1):
|
| 519 |
+
return str(batch_stride)
|
| 520 |
+
|
| 521 |
+
return "{}*{}".format(
|
| 522 |
+
self.find_symbol(node, "size", dim=1) or node.get_size()[1],
|
| 523 |
+
self.find_symbol(node, "size", dim=2) or node.get_size()[2],
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
|
| 527 |
+
"""
|
| 528 |
+
Hook called from template code to get the row or column stride of an arg.
|
| 529 |
+
This is required by some CUTLASS 2.X APIs.
|
| 530 |
+
If the node is in row_major, it returns stride[-2].
|
| 531 |
+
If the node is in column_major, it returns stride[-1].
|
| 532 |
+
|
| 533 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
if node is None or len(node.get_stride()) < 2:
|
| 537 |
+
return str(default_value)
|
| 538 |
+
|
| 539 |
+
stride0 = node.get_stride()[-1]
|
| 540 |
+
stride1 = node.get_stride()[-2]
|
| 541 |
+
if stride0 == 1:
|
| 542 |
+
return cexpr(self.rename_indexing(stride1))
|
| 543 |
+
elif stride1 == 1:
|
| 544 |
+
return cexpr(self.rename_indexing(stride0))
|
| 545 |
+
else:
|
| 546 |
+
raise RuntimeError(
|
| 547 |
+
f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable:
|
| 551 |
+
"""
|
| 552 |
+
Mock load function for memory planning to optimize allocations properly.
|
| 553 |
+
"""
|
| 554 |
+
return self.create_cse_var(name, bounds=ValueRanges.unknown())
|
| 555 |
+
|
| 556 |
+
def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
|
| 557 |
+
"""
|
| 558 |
+
Mock store function for memory planning to optimize allocations properly.
|
| 559 |
+
"""
|
| 560 |
+
self.store_buffer_names.add(name)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class CUDATemplateCaller(ChoiceCaller):
|
| 564 |
+
"""
|
| 565 |
+
CUDATemplateCaller
|
| 566 |
+
|
| 567 |
+
This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
|
| 568 |
+
Attributes:
|
| 569 |
+
name (str): The name of the caller.
|
| 570 |
+
category (str): The category of the caller.
|
| 571 |
+
bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
|
| 572 |
+
template_buffer (CUDATemplateBuffer): The template buffer for the caller.
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
def __init__(
|
| 576 |
+
self,
|
| 577 |
+
name: str,
|
| 578 |
+
category: str,
|
| 579 |
+
input_nodes: list[Buffer],
|
| 580 |
+
layout: Layout,
|
| 581 |
+
make_kernel_render: Callable[
|
| 582 |
+
[CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]],
|
| 583 |
+
tuple[CUDATemplateKernel, functools.partial[str]],
|
| 584 |
+
],
|
| 585 |
+
bmreq: CUDABenchmarkRequest,
|
| 586 |
+
supports_epilogue_fusion: bool,
|
| 587 |
+
template: "CUDATemplate", # type: ignore[name-defined]
|
| 588 |
+
info_kwargs: Optional[
|
| 589 |
+
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
|
| 590 |
+
], # type: ignore[type-arg]
|
| 591 |
+
description: str,
|
| 592 |
+
) -> None:
|
| 593 |
+
super().__init__(name, input_nodes, layout, description)
|
| 594 |
+
self.category = category
|
| 595 |
+
self.make_kernel_render = make_kernel_render
|
| 596 |
+
self.bmreq = bmreq
|
| 597 |
+
self.supports_epilogue_fusion = supports_epilogue_fusion
|
| 598 |
+
self.template = template
|
| 599 |
+
self.info_kwargs = info_kwargs
|
| 600 |
+
|
| 601 |
+
def precompile(self) -> None:
|
| 602 |
+
assert self.bmreq is not None
|
| 603 |
+
self.bmreq.precompile()
|
| 604 |
+
|
| 605 |
+
def benchmark(self, *args, out) -> float:
|
| 606 |
+
assert self.bmreq is not None
|
| 607 |
+
if config.profile_bandwidth_with_do_bench_using_profiling:
|
| 608 |
+
algo = self.bmreq.make_run_fn(*args, out=out)
|
| 609 |
+
return do_bench_using_profiling(algo)
|
| 610 |
+
return self.bmreq.benchmark(*args, out=out)
|
| 611 |
+
|
| 612 |
+
def __str__(self) -> str:
|
| 613 |
+
return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
|
| 614 |
+
|
| 615 |
+
def call_name(self) -> str:
|
| 616 |
+
return f"cuda_template_kernels.{self.name}"
|
| 617 |
+
|
| 618 |
+
def kernel_hash_key(self) -> str:
|
| 619 |
+
"""
|
| 620 |
+
Return kernel hash key that does not depend on swizzle.
|
| 621 |
+
"""
|
| 622 |
+
return "-".join(
|
| 623 |
+
[
|
| 624 |
+
self.category,
|
| 625 |
+
self.bmreq.hash_key,
|
| 626 |
+
]
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
def hash_key(self) -> str:
|
| 630 |
+
"""
|
| 631 |
+
Return kernel hash key that does not depend on swizzle.
|
| 632 |
+
"""
|
| 633 |
+
return "-".join(
|
| 634 |
+
[
|
| 635 |
+
self.category,
|
| 636 |
+
self.bmreq.hash_key,
|
| 637 |
+
str(self.info_dict().get("swizzle")),
|
| 638 |
+
]
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
|
| 642 |
+
"""Information returned here is logged to the autotune log file when that is enabled."""
|
| 643 |
+
if self.info_kwargs is not None and "op" in self.info_kwargs:
|
| 644 |
+
op: Any = self.info_kwargs["op"]
|
| 645 |
+
return {
|
| 646 |
+
"backend": "CUDA",
|
| 647 |
+
"op_type": type(op).__name__,
|
| 648 |
+
"op_conf_name": str(op.configuration_name()),
|
| 649 |
+
"op_arch": str(op.arch),
|
| 650 |
+
"tile_shape": str(op.tile_description.tile_shape),
|
| 651 |
+
"epilogue_schedule": str(op.epilogue_schedule),
|
| 652 |
+
"kernel_schedule": str(op.kernel_schedule),
|
| 653 |
+
"element_accumulator": str(op.accumulator_type()),
|
| 654 |
+
"op_name": str(op.procedural_name()),
|
| 655 |
+
"instruction_shape": str(
|
| 656 |
+
op.tile_description.math_instruction.instruction_shape
|
| 657 |
+
),
|
| 658 |
+
"swizzle": str(self.info_kwargs["swizzle"]),
|
| 659 |
+
}
|
| 660 |
+
else:
|
| 661 |
+
return {"backend": "CUDA", "op_type": "unknown"}
|
| 662 |
+
|
| 663 |
+
def output_node(self) -> TensorBox:
|
| 664 |
+
self.bmreq.update_workspace_size()
|
| 665 |
+
return TensorBox.create(
|
| 666 |
+
CUDATemplateBuffer(
|
| 667 |
+
layout=self.layout,
|
| 668 |
+
inputs=self.input_nodes,
|
| 669 |
+
make_kernel_render=self.make_kernel_render,
|
| 670 |
+
workspace_size=self.bmreq.workspace_size,
|
| 671 |
+
supports_epilogue_fusion=self.supports_epilogue_fusion,
|
| 672 |
+
template=self.template,
|
| 673 |
+
)
|
| 674 |
+
)
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import hashlib
|
| 4 |
+
import itertools
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Optional, TYPE_CHECKING
|
| 7 |
+
from typing_extensions import override
|
| 8 |
+
from unittest.mock import patch
|
| 9 |
+
|
| 10 |
+
import sympy
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch._inductor.utils import Placeholder
|
| 14 |
+
from torch._logging import getArtifactLogger
|
| 15 |
+
|
| 16 |
+
from ...autotune_process import CUDABenchmarkRequest, TensorMeta
|
| 17 |
+
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
|
| 18 |
+
from ...utils import IndentedBuffer, unique
|
| 19 |
+
from ...virtualized import V
|
| 20 |
+
from ..common import KernelTemplate
|
| 21 |
+
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
|
| 22 |
+
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from ...scheduler import BaseSchedulerNode # noqa: TC004
|
| 27 |
+
else:
|
| 28 |
+
BaseSchedulerNode = Any
|
| 29 |
+
|
| 30 |
+
GemmOperation = Any
|
| 31 |
+
|
| 32 |
+
autotuning_log = getArtifactLogger(__name__, "autotuning")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class ArgInfo:
|
| 37 |
+
name: str
|
| 38 |
+
ty: str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CUDATemplate(KernelTemplate):
|
| 42 |
+
index_counter = itertools.count()
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
name: str,
|
| 47 |
+
input_nodes: list[Buffer],
|
| 48 |
+
layout: Layout,
|
| 49 |
+
input_reorder: Optional[list[int]] = None,
|
| 50 |
+
) -> None:
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
name (str): The name of the CUDATemplate object.
|
| 57 |
+
input_nodes (List[IRNode]): A list of input IRNodes.
|
| 58 |
+
layout (Layout): The layout of the output buffer / tensor.
|
| 59 |
+
input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.
|
| 60 |
+
|
| 61 |
+
"""
|
| 62 |
+
super().__init__(name)
|
| 63 |
+
self.input_nodes = input_nodes
|
| 64 |
+
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
|
| 65 |
+
self.input_reorder = input_reorder
|
| 66 |
+
self.layout = layout
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def supports_epilogue_fusion(op: GemmOperation) -> bool:
|
| 70 |
+
return False
|
| 71 |
+
|
| 72 |
+
def generate( # type: ignore[override]
|
| 73 |
+
self,
|
| 74 |
+
description,
|
| 75 |
+
**kwargs,
|
| 76 |
+
) -> CUDATemplateCaller:
|
| 77 |
+
"""
|
| 78 |
+
Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
|
| 79 |
+
may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
kwargs: Additional keyword arguments.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
A CUDATemplateCaller object representing the generated CUDA template caller.
|
| 86 |
+
"""
|
| 87 |
+
kernel_name = str(Placeholder.KERNEL_NAME)
|
| 88 |
+
with (
|
| 89 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
|
| 90 |
+
CUDATemplateKernel(
|
| 91 |
+
kernel_name=kernel_name,
|
| 92 |
+
runtime_arg_info=self.get_runtime_arg_info(),
|
| 93 |
+
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
| 94 |
+
) as kernel,
|
| 95 |
+
):
|
| 96 |
+
code = self.render(kernel=kernel, **kwargs)
|
| 97 |
+
_, call_args, _, _ = kernel.args.python_argdefs()
|
| 98 |
+
autotuning_log.debug("Generated Code:\n%s", code)
|
| 99 |
+
autotuning_log.debug(
|
| 100 |
+
"Args: cpp_argdefs: %s, python_argdefs: %s",
|
| 101 |
+
kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE),
|
| 102 |
+
kernel.args.python_argdefs(),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
input_reorder = (
|
| 106 |
+
self.input_reorder
|
| 107 |
+
if self.input_reorder is not None
|
| 108 |
+
else list(range(len(self.input_nodes)))
|
| 109 |
+
)
|
| 110 |
+
expected_args = list(
|
| 111 |
+
unique(self.input_nodes[idx].get_name() for idx in input_reorder)
|
| 112 |
+
)
|
| 113 |
+
expected_args.extend([self.output_node.get_name()])
|
| 114 |
+
assert list(call_args)[: len(expected_args)] == expected_args, (
|
| 115 |
+
call_args,
|
| 116 |
+
expected_args,
|
| 117 |
+
)
|
| 118 |
+
V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
|
| 119 |
+
size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
|
| 120 |
+
extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
|
| 121 |
+
|
| 122 |
+
kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8]
|
| 123 |
+
kernel_name = f"cutlass_{kernel_hash}"
|
| 124 |
+
code = code.replace(self.name, kernel_name)
|
| 125 |
+
|
| 126 |
+
# create the BenchmarkRequest
|
| 127 |
+
bmreq = CUDABenchmarkRequest(
|
| 128 |
+
kernel_name=kernel_name,
|
| 129 |
+
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
|
| 130 |
+
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
| 131 |
+
extra_args=extra_args,
|
| 132 |
+
source_code=code,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# kwargs has "op" argument in case of CUTLASSGemmTemplate
|
| 136 |
+
op = kwargs["op"]
|
| 137 |
+
if not op:
|
| 138 |
+
supports_epilogue_fusion = False
|
| 139 |
+
else:
|
| 140 |
+
# epilogue fusion is only supported for TMA kernels
|
| 141 |
+
supports_epilogue_fusion = self.supports_epilogue_fusion(op)
|
| 142 |
+
|
| 143 |
+
def make_kernel_render(
|
| 144 |
+
template_node: CUDATemplateBuffer,
|
| 145 |
+
epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
|
| 146 |
+
) -> tuple[CUDATemplateKernel, functools.partial[str]]:
|
| 147 |
+
assert supports_epilogue_fusion or not epilogue_nodes, (
|
| 148 |
+
"epilogue fusion is not supported for this kernel"
|
| 149 |
+
)
|
| 150 |
+
kernel = CUDATemplateKernel(
|
| 151 |
+
kernel_name=str(Placeholder.KERNEL_NAME),
|
| 152 |
+
runtime_arg_info=self.get_runtime_arg_info(),
|
| 153 |
+
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
| 154 |
+
)
|
| 155 |
+
render = functools.partial(
|
| 156 |
+
self.render,
|
| 157 |
+
kernel=kernel,
|
| 158 |
+
template_buffer_node=template_node,
|
| 159 |
+
epilogue_nodes=epilogue_nodes,
|
| 160 |
+
**kwargs, # includes "op" argument in case of CUTLASSGemmTemplate
|
| 161 |
+
)
|
| 162 |
+
return kernel, render
|
| 163 |
+
|
| 164 |
+
return CUDATemplateCaller(
|
| 165 |
+
kernel_name,
|
| 166 |
+
"cutlass_gemm",
|
| 167 |
+
self.input_nodes,
|
| 168 |
+
self.output_node.get_layout(),
|
| 169 |
+
make_kernel_render,
|
| 170 |
+
bmreq,
|
| 171 |
+
supports_epilogue_fusion,
|
| 172 |
+
self,
|
| 173 |
+
kwargs,
|
| 174 |
+
description,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def header(self) -> IndentedBuffer:
|
| 178 |
+
res = IndentedBuffer()
|
| 179 |
+
res.splice(
|
| 180 |
+
"""
|
| 181 |
+
#include <exception>
|
| 182 |
+
#include <iostream>
|
| 183 |
+
#include <memory>
|
| 184 |
+
#include <random>
|
| 185 |
+
#include <vector>
|
| 186 |
+
"""
|
| 187 |
+
)
|
| 188 |
+
return res
|
| 189 |
+
|
| 190 |
+
def globals(self) -> IndentedBuffer:
|
| 191 |
+
res = IndentedBuffer()
|
| 192 |
+
res.splice(
|
| 193 |
+
"""
|
| 194 |
+
// We compile all models with -fvisibility=hidden. Any symbols that need to be
|
| 195 |
+
// exposed in the final shared library must be declared with PT_EXPORT to make
|
| 196 |
+
// them visible.
|
| 197 |
+
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
|
| 198 |
+
#define PT_EXPORT __attribute__((__visibility__("default")))
|
| 199 |
+
#else
|
| 200 |
+
#ifdef _WIN32
|
| 201 |
+
#define PT_EXPORT __declspec(dllexport)
|
| 202 |
+
#else
|
| 203 |
+
#define PT_EXPORT
|
| 204 |
+
#endif
|
| 205 |
+
#endif
|
| 206 |
+
"""
|
| 207 |
+
)
|
| 208 |
+
return res
|
| 209 |
+
|
| 210 |
+
def render(self, **kwargs) -> str:
|
| 211 |
+
raise NotImplementedError
|
| 212 |
+
|
| 213 |
+
def get_runtime_arg_info(self) -> list[ArgInfo]:
|
| 214 |
+
return []
|
| 215 |
+
|
| 216 |
+
def get_runtime_arg_values(self, **kwargs) -> list[Any]:
|
| 217 |
+
return []
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class CUTLASSTemplate(CUDATemplate):
|
| 221 |
+
"""
|
| 222 |
+
CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
|
| 223 |
+
CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
def header(self) -> IndentedBuffer:
|
| 227 |
+
res = super().header()
|
| 228 |
+
res.splice(
|
| 229 |
+
"""
|
| 230 |
+
#include "cute/tensor.hpp"
|
| 231 |
+
#include "cutlass/cutlass.h"
|
| 232 |
+
#include "cutlass/numeric_types.h"
|
| 233 |
+
#include "cutlass/tensor_ref.h"
|
| 234 |
+
#include "cutlass/util/host_tensor.h"
|
| 235 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 236 |
+
#include "cutlass/util/reference/device/tensor_fill.h"
|
| 237 |
+
#include "cutlass/util/device_memory.h"
|
| 238 |
+
"""
|
| 239 |
+
)
|
| 240 |
+
return res
|
| 241 |
+
|
| 242 |
+
def globals(self) -> IndentedBuffer:
|
| 243 |
+
res = super().globals()
|
| 244 |
+
res.splice(
|
| 245 |
+
"""
|
| 246 |
+
using namespace cute;
|
| 247 |
+
#define CUTLASS_CHECK(status) \\
|
| 248 |
+
{ \\
|
| 249 |
+
cutlass::Status error = status; \\
|
| 250 |
+
if (error != cutlass::Status::kSuccess) { \\
|
| 251 |
+
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\
|
| 252 |
+
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\
|
| 253 |
+
throw std::runtime_error(msg); \\
|
| 254 |
+
} \\
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// Used as pass-through functor in EVT just for type casting / rounding
|
| 258 |
+
template <typename T>
|
| 259 |
+
struct identity_op {
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
T operator()(T val) const { return val; }
|
| 262 |
+
};
|
| 263 |
+
|
| 264 |
+
"""
|
| 265 |
+
)
|
| 266 |
+
return res
|
| 267 |
+
|
| 268 |
+
def cute_int(self, int_str: str, var_name: str) -> str:
|
| 269 |
+
res = ""
|
| 270 |
+
if int_str in ("1", "1L"):
|
| 271 |
+
res = "cute::Int<1>{}"
|
| 272 |
+
else:
|
| 273 |
+
res = int_str
|
| 274 |
+
|
| 275 |
+
return f"{res} /* {var_name} */"
|
| 276 |
+
|
| 277 |
+
_DTYPE_TO_CUTLASS = {
|
| 278 |
+
torch.float32: "float",
|
| 279 |
+
torch.float64: "double",
|
| 280 |
+
torch.float16: "cutlass::half_t",
|
| 281 |
+
torch.int32: "int32_t",
|
| 282 |
+
torch.int16: "int16_t",
|
| 283 |
+
torch.int8: "int8_t",
|
| 284 |
+
torch.uint8: "uint8_t",
|
| 285 |
+
torch.bool: "bool",
|
| 286 |
+
torch.bfloat16: "cutlass::bfloat16_t",
|
| 287 |
+
torch.float8_e4m3fn: "cutlass::float_e4m3_t",
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
_DTYPE_TO_CUTLASS_SPARSE_META = {
|
| 291 |
+
torch.int32: "uint32_t",
|
| 292 |
+
torch.int16: "uint16_t",
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
|
| 296 |
+
if node is None:
|
| 297 |
+
return ptr
|
| 298 |
+
else:
|
| 299 |
+
return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
|
| 300 |
+
|
| 301 |
+
def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str:
|
| 302 |
+
if node is None:
|
| 303 |
+
return ptr
|
| 304 |
+
else:
|
| 305 |
+
return (
|
| 306 |
+
f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
@override
|
| 310 |
+
def get_runtime_arg_info(self) -> list[ArgInfo]:
|
| 311 |
+
return [ArgInfo("swizzle", "const uint8_t")]
|
| 312 |
+
|
| 313 |
+
@override
|
| 314 |
+
def get_runtime_arg_values(self, **kwargs) -> list[Any]:
|
| 315 |
+
"""
|
| 316 |
+
Helper method to retrieve runtime args from generate kwargs
|
| 317 |
+
"""
|
| 318 |
+
return [kwargs[arg.name] for arg in self.get_runtime_arg_info()]
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from typing import Any, Optional
|
| 9 |
+
|
| 10 |
+
import torch._inductor.config as config
|
| 11 |
+
from torch._inductor.codecache import cutlass_key
|
| 12 |
+
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
|
| 13 |
+
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
| 14 |
+
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
| 15 |
+
from torch._inductor.utils import clear_on_fresh_cache
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
CONFIG_PREFIX: str = "configs"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_config_request_key(
|
| 25 |
+
arch: str,
|
| 26 |
+
cuda_version: str,
|
| 27 |
+
instantiation_level: str,
|
| 28 |
+
) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level.
|
| 31 |
+
"""
|
| 32 |
+
hash_target = "-".join(
|
| 33 |
+
[
|
| 34 |
+
cutlass_key().hex(),
|
| 35 |
+
arch,
|
| 36 |
+
cuda_version,
|
| 37 |
+
instantiation_level,
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _generate_config_filename(request_key: str) -> str:
|
| 44 |
+
"""
|
| 45 |
+
Generate a filename for the full ops.
|
| 46 |
+
"""
|
| 47 |
+
return f"{CONFIG_PREFIX}_{request_key}.json"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@clear_on_fresh_cache
|
| 51 |
+
@functools.cache
|
| 52 |
+
def maybe_fetch_ops() -> Optional[list[Any]]:
|
| 53 |
+
"""
|
| 54 |
+
Fetch ops from databases.
|
| 55 |
+
"""
|
| 56 |
+
if config.force_disable_caches:
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
# setup
|
| 60 |
+
arch: str = get_cuda_arch()
|
| 61 |
+
# get_cuda_version might return "12.4.0" or "12.4"
|
| 62 |
+
# but we want to use "12.4"
|
| 63 |
+
version: str = ".".join(get_cuda_version().split(".")[:2])
|
| 64 |
+
instantiation_level: str = config.cuda.cutlass_instantiation_level
|
| 65 |
+
|
| 66 |
+
# filename and filepath
|
| 67 |
+
request_key: str = get_config_request_key(arch, version, instantiation_level)
|
| 68 |
+
filename: str = _generate_config_filename(request_key)
|
| 69 |
+
filepath: str = os.path.join(cache_dir(), filename)
|
| 70 |
+
|
| 71 |
+
# try fetch
|
| 72 |
+
serialized_ops: Optional[list[str]] = None
|
| 73 |
+
start_time = time.time()
|
| 74 |
+
if os.path.isfile(filepath):
|
| 75 |
+
# locally
|
| 76 |
+
try:
|
| 77 |
+
with open(filepath) as f:
|
| 78 |
+
serialized_ops = json.load(f)
|
| 79 |
+
|
| 80 |
+
assert isinstance(serialized_ops, list), (
|
| 81 |
+
f"Expected serialized ops is a list, got {type(serialized_ops)}"
|
| 82 |
+
)
|
| 83 |
+
except Exception as e:
|
| 84 |
+
log.warning(
|
| 85 |
+
"Failed to load CUTLASS config %s from local cache: %s",
|
| 86 |
+
filename,
|
| 87 |
+
e,
|
| 88 |
+
)
|
| 89 |
+
serialized_ops = None
|
| 90 |
+
elif config.is_fbcode():
|
| 91 |
+
from torch._inductor.fb.cutlass_remote_cache import (
|
| 92 |
+
maybe_fetch_cutlass_configs_from_remote,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# from remote
|
| 96 |
+
serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath)
|
| 97 |
+
|
| 98 |
+
if serialized_ops is None:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
# deserialize
|
| 102 |
+
serializer = get_cutlass_operation_serializer()
|
| 103 |
+
full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr]
|
| 104 |
+
log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time)
|
| 105 |
+
return full_ops
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Union
|
| 2 |
+
|
| 3 |
+
from sympy import Expr
|
| 4 |
+
|
| 5 |
+
from torch._inductor.ir import (
|
| 6 |
+
ComputedBuffer,
|
| 7 |
+
InputBuffer,
|
| 8 |
+
is_contiguous_strides_for_shape,
|
| 9 |
+
)
|
| 10 |
+
from torch.utils._ordered_set import OrderedSet
|
| 11 |
+
|
| 12 |
+
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
|
| 16 |
+
Buffer = Union[ComputedBuffer, InputBuffer]
|
| 17 |
+
CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_.<locals>.TupleType
|
| 18 |
+
CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory.<locals>.VisitorType
|
| 19 |
+
CutlassArgType = (
|
| 20 |
+
Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if try_import_cutlass():
|
| 25 |
+
import ast
|
| 26 |
+
import ctypes
|
| 27 |
+
import textwrap
|
| 28 |
+
from typing import Union
|
| 29 |
+
|
| 30 |
+
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
|
| 31 |
+
EmptyByte,
|
| 32 |
+
)
|
| 33 |
+
from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found]
|
| 34 |
+
dtype2ctype,
|
| 35 |
+
)
|
| 36 |
+
from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
|
| 37 |
+
EpilogueFunctorVisitor,
|
| 38 |
+
)
|
| 39 |
+
from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found]
|
| 40 |
+
FusionCallbacks,
|
| 41 |
+
)
|
| 42 |
+
from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found]
|
| 43 |
+
CollectiveEpilogue,
|
| 44 |
+
)
|
| 45 |
+
from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found]
|
| 46 |
+
PythonASTFrontend,
|
| 47 |
+
)
|
| 48 |
+
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
|
| 49 |
+
Tensor as CutlassTensor,
|
| 50 |
+
)
|
| 51 |
+
from cutlass_library import (
|
| 52 |
+
DataType,
|
| 53 |
+
EpilogueScheduleType,
|
| 54 |
+
LayoutType,
|
| 55 |
+
TileDescription,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
from torch._inductor.codegen.cuda import cuda_env
|
| 59 |
+
from torch._inductor.utils import IndentedBuffer
|
| 60 |
+
|
| 61 |
+
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
|
| 62 |
+
|
| 63 |
+
def create_example_tensors(
|
| 64 |
+
var_name_to_buffer_name: dict[str, str],
|
| 65 |
+
name_to_buffer: dict[str, Buffer],
|
| 66 |
+
size_hint_fn: Callable[[Union[Expr, int]], int],
|
| 67 |
+
) -> dict[str, CutlassTensor]:
|
| 68 |
+
def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
|
| 69 |
+
shape = buffer.get_layout().size
|
| 70 |
+
stride = buffer.get_layout().stride
|
| 71 |
+
shape = tuple(size_hint_fn(x) for x in shape)
|
| 72 |
+
stride = tuple(size_hint_fn(x) for x in stride)
|
| 73 |
+
|
| 74 |
+
is_row_major = is_contiguous_strides_for_shape(stride, shape)
|
| 75 |
+
is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
|
| 76 |
+
|
| 77 |
+
if not is_row_major and not is_column_major:
|
| 78 |
+
raise RuntimeError(
|
| 79 |
+
f"Cannot create example tensor for {buffer.get_name()} with \
|
| 80 |
+
non-contiguous layout, received stride: {stride} and shape: {shape}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return CutlassTensor(
|
| 84 |
+
shape=shape,
|
| 85 |
+
layout_tag=LayoutType.RowMajor
|
| 86 |
+
if is_row_major
|
| 87 |
+
else LayoutType.ColumnMajor,
|
| 88 |
+
element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
return {
|
| 92 |
+
key: cutlass_tensor_from_buffer(name_to_buffer[name])
|
| 93 |
+
for key, name in var_name_to_buffer_name.items()
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def trace(
|
| 97 |
+
fn_src: str,
|
| 98 |
+
example_tensors: dict[str, CutlassTensor],
|
| 99 |
+
accum_type: DataType,
|
| 100 |
+
output_type: DataType,
|
| 101 |
+
tile_description: TileDescription,
|
| 102 |
+
epilogue_schedule: EpilogueScheduleType,
|
| 103 |
+
name_to_buffer: dict[str, Buffer],
|
| 104 |
+
size_hint_fn: Callable[[Union[Expr, int]], int],
|
| 105 |
+
**kwargs: dict[str, Any],
|
| 106 |
+
) -> tuple[str, str, str]:
|
| 107 |
+
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
|
| 108 |
+
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
|
| 109 |
+
epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs)
|
| 110 |
+
visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
|
| 111 |
+
fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
|
| 112 |
+
collective_epilogue = CollectiveEpilogue(
|
| 113 |
+
tile_description,
|
| 114 |
+
epilogue_schedule,
|
| 115 |
+
accum_type,
|
| 116 |
+
output_type,
|
| 117 |
+
fusion_callbacks,
|
| 118 |
+
)
|
| 119 |
+
evt_name, evt_code = collective_epilogue.emit()
|
| 120 |
+
evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn)
|
| 121 |
+
return evt_name, evt_args, evt_code
|
| 122 |
+
|
| 123 |
+
# Based off of
|
| 124 |
+
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
|
| 125 |
+
# This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function
|
| 126 |
+
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
|
| 127 |
+
def _trace(
|
| 128 |
+
fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any
|
| 129 |
+
) -> EpilogueFunctor:
|
| 130 |
+
class EpilogueFunctor(PythonASTFrontend):
|
| 131 |
+
def __init__(self, cc: int, **kwargs: Any):
|
| 132 |
+
self.source = textwrap.dedent(fn_src)
|
| 133 |
+
super().__init__(cc, **kwargs)
|
| 134 |
+
|
| 135 |
+
def parse(self, example_inputs: dict[str, CutlassTensor]) -> None:
|
| 136 |
+
self.example_inputs = example_inputs
|
| 137 |
+
self.ast = ast.parse(self.source)
|
| 138 |
+
self.visit(self.ast)
|
| 139 |
+
|
| 140 |
+
cc = int(cuda_env.get_cuda_arch())
|
| 141 |
+
epilogue_functor = EpilogueFunctor(cc=cc, **kwargs)
|
| 142 |
+
epilogue_functor.trace(example_tensors)
|
| 143 |
+
return epilogue_functor
|
| 144 |
+
|
| 145 |
+
def _render_argument_type(
|
| 146 |
+
epilogue_functor: EpilogueFunctor,
|
| 147 |
+
name_to_buffer: dict[str, Buffer],
|
| 148 |
+
size_hint_fn: Callable[[Union[Expr, int]], int],
|
| 149 |
+
) -> str:
|
| 150 |
+
epilogue_thread_type = epilogue_functor.epilogue_thread_type
|
| 151 |
+
|
| 152 |
+
# Fragile, but this is the only way to guarantee t is expected type because t is a local class
|
| 153 |
+
def is_nested_visitor_type(t: type) -> bool:
|
| 154 |
+
return (
|
| 155 |
+
".".join([t.__module__, t.__qualname__])
|
| 156 |
+
== "cutlass.backend.c_types.visitor_factory.<locals>.VisitorType"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
buffer = IndentedBuffer()
|
| 160 |
+
with buffer.set_tabwidth(2):
|
| 161 |
+
|
| 162 |
+
def render_argument_type(name: str, t: CutlassArgType) -> None:
|
| 163 |
+
if issubclass(t, ctypes.c_byte):
|
| 164 |
+
buffer.writeline(f"{{}}, /* {name} */")
|
| 165 |
+
else:
|
| 166 |
+
fields = [
|
| 167 |
+
(
|
| 168 |
+
fname,
|
| 169 |
+
_get_arg_from_node(ty, name_to_buffer[name], size_hint_fn),
|
| 170 |
+
)
|
| 171 |
+
for fname, ty in t._fields_
|
| 172 |
+
]
|
| 173 |
+
field_strs = [
|
| 174 |
+
f"/* {fname} */ {str(field)}" for fname, field in fields
|
| 175 |
+
]
|
| 176 |
+
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
|
| 177 |
+
|
| 178 |
+
def render_thread_type(name: str, t: CutlassArgType) -> None:
|
| 179 |
+
if is_nested_visitor_type(t):
|
| 180 |
+
buffer.writeline(f"{{ /* {name} */")
|
| 181 |
+
with buffer.indent():
|
| 182 |
+
for name, inner_t in t._fields_:
|
| 183 |
+
render_thread_type(name, inner_t)
|
| 184 |
+
buffer.writeline("},")
|
| 185 |
+
else:
|
| 186 |
+
render_argument_type(name, t)
|
| 187 |
+
|
| 188 |
+
# unroll the recursion once to address special case formatting
|
| 189 |
+
# namely, no ending comma and no indentation for the outermost thread type
|
| 190 |
+
buffer.writeline("{ /* thread */")
|
| 191 |
+
with buffer.indent(3):
|
| 192 |
+
if is_nested_visitor_type(epilogue_thread_type):
|
| 193 |
+
with buffer.indent():
|
| 194 |
+
for name, inner_t in epilogue_thread_type._fields_:
|
| 195 |
+
render_thread_type(name, inner_t)
|
| 196 |
+
else:
|
| 197 |
+
render_argument_type("thread", epilogue_thread_type)
|
| 198 |
+
buffer.writeline("}")
|
| 199 |
+
|
| 200 |
+
return buffer.getvalue()
|
| 201 |
+
|
| 202 |
+
def _get_arg_from_node(
|
| 203 |
+
arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int]
|
| 204 |
+
) -> str:
|
| 205 |
+
from ..cuda_template import CUTLASSTemplate
|
| 206 |
+
|
| 207 |
+
# Today, arguments are either a pointer to the
|
| 208 |
+
# node's memory, a stride tuple, the datatype
|
| 209 |
+
# Once again, need to check for local class type for stride tuple
|
| 210 |
+
if (
|
| 211 |
+
str(arg_ty)
|
| 212 |
+
== "<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>"
|
| 213 |
+
):
|
| 214 |
+
DEFAULT_STRIDE_LEN = 3
|
| 215 |
+
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
| 216 |
+
stride = [size_hint_fn(x) for x in node.get_layout().stride]
|
| 217 |
+
for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
|
| 218 |
+
stride.append(0)
|
| 219 |
+
|
| 220 |
+
def render_stride(x: int) -> str:
|
| 221 |
+
# Handle EBO for 0 and 1
|
| 222 |
+
if x == 0:
|
| 223 |
+
return "_0{}"
|
| 224 |
+
elif x == 1:
|
| 225 |
+
return "_1{}"
|
| 226 |
+
else:
|
| 227 |
+
return str(x)
|
| 228 |
+
|
| 229 |
+
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
|
| 230 |
+
|
| 231 |
+
elif issubclass(arg_ty, ctypes.c_void_p):
|
| 232 |
+
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}"
|
| 233 |
+
elif (
|
| 234 |
+
arg_ty in _CUTLASS_C_DTYPES
|
| 235 |
+
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
|
| 236 |
+
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
|
| 237 |
+
elif issubclass(arg_ty, EmptyByte):
|
| 238 |
+
return "{}"
|
| 239 |
+
|
| 240 |
+
raise NotImplementedError(f"Unsupported arg type: {arg_ty}")
|
.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
from ..cutlass_utils import try_import_cutlass
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# copied / modified from original at
|
| 6 |
+
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
|
| 7 |
+
|
| 8 |
+
if try_import_cutlass():
|
| 9 |
+
import enum
|
| 10 |
+
|
| 11 |
+
from cutlass_library.gemm_operation import * # noqa: F401, F403
|
| 12 |
+
from cutlass_library.library import * # noqa: F401, F403
|
| 13 |
+
|
| 14 |
+
_LOGGER = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class EmitGemmUniversal3xInstanceWithEVT:
|
| 17 |
+
"""Responsible for emitting a CUTLASS 3.x template definition"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, operation_suffix="", evt_name=None):
|
| 20 |
+
self.operation_suffix = operation_suffix
|
| 21 |
+
self.includes = [
|
| 22 |
+
"cutlass/cutlass.h",
|
| 23 |
+
"cutlass/gemm/gemm.h",
|
| 24 |
+
"cutlass/numeric_types.h",
|
| 25 |
+
"cutlass/gemm/kernel/gemm_universal.hpp",
|
| 26 |
+
"cutlass/gemm/collective/collective_builder.hpp",
|
| 27 |
+
"cutlass/epilogue/collective/collective_builder.hpp",
|
| 28 |
+
]
|
| 29 |
+
self.builtin_epilogue_functor_template = """${epilogue_functor}<
|
| 30 |
+
${element_d},
|
| 31 |
+
${element_epilogue},
|
| 32 |
+
${element_c},
|
| 33 |
+
${element_epilogue}
|
| 34 |
+
>"""
|
| 35 |
+
|
| 36 |
+
self.evt_name = evt_name
|
| 37 |
+
self.gemm_template = """
|
| 38 |
+
using ${operation_name}_epilogue =
|
| 39 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 40 |
+
${arch}, ${opcode_class_epi},
|
| 41 |
+
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
| 42 |
+
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
| 43 |
+
${epi_tile_mn},
|
| 44 |
+
${element_accumulator}, ${element_epilogue},
|
| 45 |
+
${element_c}, ${layout_c}, ${align_c},
|
| 46 |
+
${element_d}, ${layout_d}, ${align_d},
|
| 47 |
+
${epilogue_schedule},
|
| 48 |
+
${epilogue_functor}
|
| 49 |
+
>::CollectiveOp;
|
| 50 |
+
|
| 51 |
+
${mixed_dtype_prepare_code}
|
| 52 |
+
|
| 53 |
+
using ${operation_name}_mainloop =
|
| 54 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 55 |
+
${arch}, ${opcode_class_main},
|
| 56 |
+
${element_a}, ${layout_a}, ${align_a},
|
| 57 |
+
${element_b}, ${layout_b}, ${align_b},
|
| 58 |
+
${element_accumulator},
|
| 59 |
+
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
| 60 |
+
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
| 61 |
+
${stages},
|
| 62 |
+
${kernel_schedule}
|
| 63 |
+
>::CollectiveOp;
|
| 64 |
+
|
| 65 |
+
// Gemm operator ${operation_name}
|
| 66 |
+
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 67 |
+
${problem_shape},
|
| 68 |
+
${operation_name}_mainloop,
|
| 69 |
+
${operation_name}_epilogue,
|
| 70 |
+
${tile_scheduler}>;
|
| 71 |
+
|
| 72 |
+
// Define named type
|
| 73 |
+
struct ${operation_name} :
|
| 74 |
+
public ${operation_name}_base { };
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
#
|
| 79 |
+
def instance_template(self):
|
| 80 |
+
return """
|
| 81 |
+
${compile_guard_start}
|
| 82 |
+
{
|
| 83 |
+
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
| 84 |
+
manifest.append(
|
| 85 |
+
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
| 86 |
+
}
|
| 87 |
+
${compile_guard_end}
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def emit_block_scale_epilogue_functor(self, operation):
|
| 91 |
+
block_scaled_template = """
|
| 92 |
+
${epilogue_functor}<
|
| 93 |
+
${epi_vs},
|
| 94 |
+
${element_d},
|
| 95 |
+
${element_accumulator},
|
| 96 |
+
${element_sfd},
|
| 97 |
+
${layout_sfd},
|
| 98 |
+
${element_c},
|
| 99 |
+
${element_scalar}
|
| 100 |
+
>
|
| 101 |
+
"""
|
| 102 |
+
block_scaled_values = {
|
| 103 |
+
"epi_vs": str(operation.ScaleFactorVectorSize),
|
| 104 |
+
"element_d": str(DataTypeTag[operation.D.element]),
|
| 105 |
+
"element_sfd": str(DataTypeTag[operation.ScaleFactorD.element]),
|
| 106 |
+
"layout_sfd": LayoutTag[operation.ScaleFactorD.layout],
|
| 107 |
+
"epilogue_functor": EpilogueFunctor3xTag[
|
| 108 |
+
EpilogueFunctor3x.LinearCombinationBlockScaleFactor
|
| 109 |
+
],
|
| 110 |
+
"element_accumulator": str(DataTypeTag[operation.accumulator_type()]),
|
| 111 |
+
"element_scalar": str(DataTypeTag[operation.accumulator_type()]),
|
| 112 |
+
"element_c": str(DataTypeTag[operation.C.element]),
|
| 113 |
+
}
|
| 114 |
+
return SubstituteTemplate(block_scaled_template, block_scaled_values)
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def pointerize_if_grouped(operation, layout):
|
| 118 |
+
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def problem_shape(operation):
|
| 122 |
+
gemm_shape_type = "cute::Shape<int,int,int,int>"
|
| 123 |
+
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
|
| 124 |
+
grouped_gemm_shape_type = (
|
| 125 |
+
"cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return (
|
| 129 |
+
gemm_shape_type
|
| 130 |
+
if not is_grouped(operation.gemm_kind)
|
| 131 |
+
else grouped_gemm_shape_type
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def emit(self, operation):
|
| 135 |
+
"""Given a gem operation, emits a template definition of the operation"""
|
| 136 |
+
|
| 137 |
+
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
| 138 |
+
opcode_class_epi = opcode_class_main
|
| 139 |
+
|
| 140 |
+
tile_shape = operation.tile_description.tile_shape
|
| 141 |
+
instruction_shape = (
|
| 142 |
+
operation.tile_description.math_instruction.instruction_shape
|
| 143 |
+
)
|
| 144 |
+
cluster_m = operation.tile_description.cluster_shape[0]
|
| 145 |
+
cluster_n = operation.tile_description.cluster_shape[1]
|
| 146 |
+
|
| 147 |
+
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
| 148 |
+
|
| 149 |
+
# account for static/dynamic cluster shapes
|
| 150 |
+
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
| 151 |
+
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
| 152 |
+
|
| 153 |
+
# Shape passed to epilogue builder
|
| 154 |
+
is_sm100_kernel = operation.arch == 100
|
| 155 |
+
if is_sm100_kernel:
|
| 156 |
+
cta_m_per_mma_instruction = (
|
| 157 |
+
2 if "2sm" in operation.procedural_name() else 1
|
| 158 |
+
)
|
| 159 |
+
if cluster_m <= 0:
|
| 160 |
+
cta_m = cta_m // cta_m_per_mma_instruction
|
| 161 |
+
|
| 162 |
+
if opcode_class_main in [
|
| 163 |
+
OpcodeClass.TensorOp,
|
| 164 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 165 |
+
]:
|
| 166 |
+
tile_shape_m = instruction_shape[0]
|
| 167 |
+
tile_shape_n = instruction_shape[1]
|
| 168 |
+
|
| 169 |
+
# stage count set to zero indicates builder automatic stage selection
|
| 170 |
+
if operation.tile_description.stages > 0:
|
| 171 |
+
stage_count_string = f"cutlass::gemm::collective::StageCount<\
|
| 172 |
+
{str(operation.tile_description.stages)}>"
|
| 173 |
+
else:
|
| 174 |
+
stage_count_string = (
|
| 175 |
+
f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\
|
| 176 |
+
sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 180 |
+
|
| 181 |
+
(
|
| 182 |
+
instance_layout_A,
|
| 183 |
+
instance_layout_B,
|
| 184 |
+
instance_layout_C,
|
| 185 |
+
instance_layout_D,
|
| 186 |
+
) = (
|
| 187 |
+
operation.A.layout,
|
| 188 |
+
operation.B.layout,
|
| 189 |
+
operation.C.layout,
|
| 190 |
+
operation.D.layout,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# 3.0 profiler integration only supports trivial epilogues for now
|
| 194 |
+
epilogue_vector_length = 1
|
| 195 |
+
|
| 196 |
+
# Support built-in epilogue functors or user-defined functions
|
| 197 |
+
if isinstance(operation.epilogue_functor, enum.Enum):
|
| 198 |
+
values = {
|
| 199 |
+
"element_epilogue": str(DataTypeTag[operation.element_epilogue]),
|
| 200 |
+
"epilogue_functor": EpilogueFunctor3xTag[
|
| 201 |
+
operation.epilogue_functor
|
| 202 |
+
],
|
| 203 |
+
}
|
| 204 |
+
epilogue_functor = SubstituteTemplate(
|
| 205 |
+
self.builtin_epilogue_functor_template, values
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
if (
|
| 209 |
+
is_block_scaled(operation.gemm_kind)
|
| 210 |
+
and operation.ScaleFactorD.element != DataType.void
|
| 211 |
+
):
|
| 212 |
+
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
| 213 |
+
else:
|
| 214 |
+
epilogue_functor = self.epilogue_functor.emit_declaration()
|
| 215 |
+
|
| 216 |
+
if (
|
| 217 |
+
is_block_scaled(operation.gemm_kind)
|
| 218 |
+
and operation.ScaleFactorD.element != DataType.void
|
| 219 |
+
):
|
| 220 |
+
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
| 221 |
+
|
| 222 |
+
#
|
| 223 |
+
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder,
|
| 224 |
+
# e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
| 225 |
+
element_a = (
|
| 226 |
+
DataTypeTag[operation.A.element]
|
| 227 |
+
if not operation.is_complex()
|
| 228 |
+
else f"cute::tuple<{str(DataTypeTag[operation.A.element])},\
|
| 229 |
+
{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
| 230 |
+
)
|
| 231 |
+
element_b = (
|
| 232 |
+
DataTypeTag[operation.B.element]
|
| 233 |
+
if not operation.is_complex()
|
| 234 |
+
else f"cute::tuple<{str(DataTypeTag[operation.B.element])},\
|
| 235 |
+
{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
| 236 |
+
)
|
| 237 |
+
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
| 238 |
+
|
| 239 |
+
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
| 240 |
+
is_no_smem_epilogue = operation.epilogue_schedule in [
|
| 241 |
+
EpilogueScheduleType.NoSmemWarpSpecialized1Sm,
|
| 242 |
+
EpilogueScheduleType.NoSmemWarpSpecialized2Sm,
|
| 243 |
+
]
|
| 244 |
+
grouped = is_grouped(operation.gemm_kind)
|
| 245 |
+
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(
|
| 246 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped
|
| 247 |
+
):
|
| 248 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 249 |
+
if not is_no_smem_epilogue:
|
| 250 |
+
epilogue_schedule_type = EpilogueScheduleTag[
|
| 251 |
+
to_grouped_schedule(
|
| 252 |
+
EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped
|
| 253 |
+
)
|
| 254 |
+
]
|
| 255 |
+
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(
|
| 256 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped
|
| 257 |
+
):
|
| 258 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 259 |
+
if not is_no_smem_epilogue:
|
| 260 |
+
epilogue_schedule_type = EpilogueScheduleTag[
|
| 261 |
+
to_grouped_schedule(
|
| 262 |
+
EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped
|
| 263 |
+
)
|
| 264 |
+
]
|
| 265 |
+
element_a = f"cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>"
|
| 266 |
+
element_b = f"cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>"
|
| 267 |
+
|
| 268 |
+
operation_name_str = operation.procedural_name()
|
| 269 |
+
layout_a_str = LayoutTag[instance_layout_A]
|
| 270 |
+
layout_b_str = LayoutTag[instance_layout_B]
|
| 271 |
+
mixed_dtype_prepare_code = ""
|
| 272 |
+
if operation.mixed_input_mode is not None:
|
| 273 |
+
A_dtype = operation.A.element
|
| 274 |
+
B_dtype = operation.B.element
|
| 275 |
+
A_dtype_bits = DataTypeSize[A_dtype]
|
| 276 |
+
B_dtype_bits = DataTypeSize[B_dtype]
|
| 277 |
+
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
|
| 278 |
+
if is_A_dtype_narrow:
|
| 279 |
+
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
|
| 280 |
+
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
|
| 281 |
+
else:
|
| 282 |
+
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
|
| 283 |
+
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
|
| 284 |
+
|
| 285 |
+
narrow_tag = DataTypeTag[narrow_dtype]
|
| 286 |
+
wide_tag = DataTypeTag[wide_dtype]
|
| 287 |
+
scale_tag = DataTypeTag[wide_dtype]
|
| 288 |
+
zero_tag = DataTypeTag[wide_dtype]
|
| 289 |
+
|
| 290 |
+
do_shuffle = False
|
| 291 |
+
value_shuffle_str = ""
|
| 292 |
+
if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
|
| 293 |
+
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, \
|
| 294 |
+
cute::Stride<cute::_4,cute::_1>>"
|
| 295 |
+
do_shuffle = True
|
| 296 |
+
if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
|
| 297 |
+
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, \
|
| 298 |
+
cute::Stride<cute::_2,cute::_1>>"
|
| 299 |
+
do_shuffle = True
|
| 300 |
+
do_shuffle = operation.mixed_input_shuffle and do_shuffle
|
| 301 |
+
|
| 302 |
+
if do_shuffle:
|
| 303 |
+
if is_A_dtype_narrow:
|
| 304 |
+
stride_narrow_str = (
|
| 305 |
+
f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
|
| 306 |
+
)
|
| 307 |
+
layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
|
| 308 |
+
else:
|
| 309 |
+
stride_narrow_str = (
|
| 310 |
+
f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
|
| 311 |
+
)
|
| 312 |
+
layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
|
| 313 |
+
# The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
|
| 314 |
+
# layout_{a, b}_str are to prevent errors in Windows platform unity build
|
| 315 |
+
mixed_dtype_prepare_code = f"""
|
| 316 |
+
using {operation_name_str}_StrideNarrow = {stride_narrow_str};
|
| 317 |
+
using {operation_name_str}_ValueShuffle = {value_shuffle_str};
|
| 318 |
+
static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
|
| 319 |
+
using {operation_name_str}_MmaAtomShape = \
|
| 320 |
+
cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
|
| 321 |
+
using {operation_name_str}_LayoutAtomQuant = \
|
| 322 |
+
decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, \
|
| 323 |
+
{operation_name_str}_ValueShuffle>());
|
| 324 |
+
using {operation_name_str}_LayoutNarrowReordered = \
|
| 325 |
+
decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, \
|
| 326 |
+
cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
mixed_input_modes_to_element = {
|
| 330 |
+
MixedInputMode.ConvertOnly: narrow_tag,
|
| 331 |
+
MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
|
| 332 |
+
MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>",
|
| 333 |
+
}
|
| 334 |
+
narrow_element = mixed_input_modes_to_element.get(
|
| 335 |
+
operation.mixed_input_mode, narrow_tag
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
if narrow_dtype == DataType.s4 and (
|
| 339 |
+
wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2
|
| 340 |
+
):
|
| 341 |
+
narrow_element = (
|
| 342 |
+
f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
if is_A_dtype_narrow:
|
| 346 |
+
element_a = narrow_element
|
| 347 |
+
else:
|
| 348 |
+
element_b = narrow_element
|
| 349 |
+
|
| 350 |
+
if self.evt_name:
|
| 351 |
+
epilogue_functor = self.evt_name
|
| 352 |
+
|
| 353 |
+
values = {
|
| 354 |
+
"operation_name": operation_name_str,
|
| 355 |
+
"operation_suffix": self.operation_suffix,
|
| 356 |
+
"problem_shape": self.problem_shape(operation),
|
| 357 |
+
"element_a": element_a,
|
| 358 |
+
"layout_a": self.pointerize_if_grouped(operation, layout_a_str),
|
| 359 |
+
"element_b": element_b,
|
| 360 |
+
"layout_b": self.pointerize_if_grouped(operation, layout_b_str),
|
| 361 |
+
"element_c": DataTypeTag[operation.C.element],
|
| 362 |
+
"layout_c": self.pointerize_if_grouped(
|
| 363 |
+
operation, LayoutTag[instance_layout_C]
|
| 364 |
+
),
|
| 365 |
+
"element_d": DataTypeTag[operation.D.element],
|
| 366 |
+
"layout_d": self.pointerize_if_grouped(
|
| 367 |
+
operation, LayoutTag[instance_layout_D]
|
| 368 |
+
),
|
| 369 |
+
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 370 |
+
"opcode_class_main": OpcodeClassTag[opcode_class_main],
|
| 371 |
+
"opcode_class_epi": OpcodeClassTag[opcode_class_epi],
|
| 372 |
+
"arch": f"cutlass::arch::Sm{operation.arch}",
|
| 373 |
+
"tile_shape_m": str(tile_shape_m),
|
| 374 |
+
"tile_shape_n": str(tile_shape_n),
|
| 375 |
+
"tile_shape_k": str(tile_shape_k),
|
| 376 |
+
"cluster_shape_m": "cute::_"
|
| 377 |
+
+ str(operation.tile_description.cluster_shape[0])
|
| 378 |
+
if operation.tile_description.cluster_shape[0] > 0
|
| 379 |
+
else "int",
|
| 380 |
+
"cluster_shape_n": "cute::_"
|
| 381 |
+
+ str(operation.tile_description.cluster_shape[1])
|
| 382 |
+
if operation.tile_description.cluster_shape[1] > 0
|
| 383 |
+
else "int",
|
| 384 |
+
"cluster_shape_k": "cute::_"
|
| 385 |
+
+ str(operation.tile_description.cluster_shape[2])
|
| 386 |
+
if operation.tile_description.cluster_shape[2] > 0
|
| 387 |
+
else "int",
|
| 388 |
+
"instruction_shape_m": str(instruction_shape[0]),
|
| 389 |
+
"instruction_shape_n": str(instruction_shape[1]),
|
| 390 |
+
"instruction_shape_k": str(instruction_shape[2]),
|
| 391 |
+
"kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),
|
| 392 |
+
"epilogue_schedule": str(epilogue_schedule_type),
|
| 393 |
+
"epi_tile_mn": epi_tile_mn,
|
| 394 |
+
"epilogue_functor": epilogue_functor,
|
| 395 |
+
"stages": stage_count_string,
|
| 396 |
+
"align_a": str(operation.A.alignment),
|
| 397 |
+
"align_b": str(operation.B.alignment),
|
| 398 |
+
"align_c": str(operation.C.alignment),
|
| 399 |
+
"align_d": str(operation.C.alignment),
|
| 400 |
+
"transform_a": ComplexTransformTag[operation.A.complex_transform],
|
| 401 |
+
"transform_b": ComplexTransformTag[operation.B.complex_transform],
|
| 402 |
+
"math_operation": MathOperationTag[
|
| 403 |
+
operation.tile_description.math_instruction.math_operation
|
| 404 |
+
],
|
| 405 |
+
"epilogue_vector_length": str(epilogue_vector_length),
|
| 406 |
+
"element_epilogue": str(DataTypeTag[operation.element_epilogue]),
|
| 407 |
+
"tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]),
|
| 408 |
+
"mixed_dtype_prepare_code": mixed_dtype_prepare_code,
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
return SubstituteTemplate(self.gemm_template, values)
|